aboutsummaryrefslogtreecommitdiff
path: root/connection_handler.go
diff options
context:
space:
mode:
Diffstat (limited to 'connection_handler.go')
-rw-r--r--connection_handler.go261
1 files changed, 147 insertions, 114 deletions
diff --git a/connection_handler.go b/connection_handler.go
index 2e2fa84..4fc2a95 100644
--- a/connection_handler.go
+++ b/connection_handler.go
@@ -1,95 +1,161 @@
package main
import (
- "context"
"encoding/binary"
- "fmt"
"github.com/flier/gohs/hyperscan"
"github.com/google/gopacket"
"github.com/google/gopacket/tcpassembly"
- "log"
+ log "github.com/sirupsen/logrus"
"sync"
"time"
)
+const initialConnectionsCapacity = 1024
+const initialScannersCapacity = 1024
+
type BiDirectionalStreamFactory struct {
- storage Storage
- serverIp gopacket.Endpoint
- connections map[StreamKey]ConnectionHandler
- mConnections sync.Mutex
- patterns hyperscan.StreamDatabase
- mPatterns sync.Mutex
- scratches []*hyperscan.Scratch
+ storage Storage
+ serverIp gopacket.Endpoint
+ connections map[StreamFlow]ConnectionHandler
+ mConnections sync.Mutex
+ rulesManager *RulesManager
+ rulesDatabase RulesDatabase
+ mRulesDatabase sync.Mutex
+ scanners []Scanner
}
-type StreamKey [4]gopacket.Endpoint
+type StreamFlow [4]gopacket.Endpoint
+
+type Scanner struct {
+ scratch *hyperscan.Scratch
+ version RowID
+}
type ConnectionHandler interface {
Complete(handler *StreamHandler)
Storage() Storage
- Context() context.Context
- Patterns() hyperscan.StreamDatabase
+ PatternsDatabase() hyperscan.StreamDatabase
}
type connectionHandlerImpl struct {
- storage Storage
- net, transport gopacket.Flow
- initiator StreamKey
- connectionKey string
+ factory *BiDirectionalStreamFactory
+ connectionFlow StreamFlow
mComplete sync.Mutex
otherStream *StreamHandler
- context context.Context
- patterns hyperscan.StreamDatabase
+}
+
+func NewBiDirectionalStreamFactory(storage Storage, serverIP gopacket.Endpoint,
+ rulesManager *RulesManager) *BiDirectionalStreamFactory {
+
+ factory := &BiDirectionalStreamFactory{
+ storage: storage,
+ serverIp: serverIP,
+ connections: make(map[StreamFlow]ConnectionHandler, initialConnectionsCapacity),
+ mConnections: sync.Mutex{},
+ rulesManager: rulesManager,
+ mRulesDatabase: sync.Mutex{},
+ scanners: make([]Scanner, 0, initialScannersCapacity),
+ }
+
+ go factory.updateRulesDatabaseService()
+ return factory
+}
+
+func (factory *BiDirectionalStreamFactory) updateRulesDatabaseService() {
+ for {
+ select {
+ case rulesDatabase, ok := <-factory.rulesManager.databaseUpdated:
+ if !ok {
+ return
+ }
+ factory.mRulesDatabase.Lock()
+ scanners := factory.scanners
+ factory.scanners = factory.scanners[:0]
+
+ for _, s := range scanners {
+ err := s.scratch.Realloc(rulesDatabase.database)
+ if err != nil {
+ log.WithError(err).Error("failed to realloc an existing scanner")
+ } else {
+ s.version = rulesDatabase.version
+ factory.scanners = append(factory.scanners, s)
+ }
+ }
+
+ factory.rulesDatabase = rulesDatabase
+ factory.mRulesDatabase.Unlock()
+ }
+ }
+}
+
+func (factory *BiDirectionalStreamFactory) takeScanner() Scanner {
+ factory.mRulesDatabase.Lock()
+ defer factory.mRulesDatabase.Unlock()
+
+ if len(factory.scanners) == 0 {
+ scratch, err := hyperscan.NewScratch(factory.rulesDatabase.database)
+ if err != nil {
+ log.WithError(err).Fatal("failed to alloc a new scratch")
+ }
+
+ return Scanner{
+ scratch: scratch,
+ version: factory.rulesDatabase.version,
+ }
+ }
+
+ index := len(factory.scanners) - 1
+ scanner := factory.scanners[index]
+ factory.scanners = factory.scanners[:index]
+
+ return scanner
+}
+
+func (factory *BiDirectionalStreamFactory) releaseScanner(scanner Scanner) {
+ factory.mRulesDatabase.Lock()
+ defer factory.mRulesDatabase.Unlock()
+
+ if scanner.version != factory.rulesDatabase.version {
+ err := scanner.scratch.Realloc(factory.rulesDatabase.database)
+ if err != nil {
+ log.WithError(err).Error("failed to realloc an existing scanner")
+ return
+ }
+ }
+ factory.scanners = append(factory.scanners, scanner)
}
func (factory *BiDirectionalStreamFactory) New(net, transport gopacket.Flow) tcpassembly.Stream {
- key := StreamKey{net.Src(), net.Dst(), transport.Src(), transport.Dst()}
- invertedKey := StreamKey{net.Dst(), net.Src(), transport.Dst(), transport.Src()}
+ flow := StreamFlow{net.Src(), net.Dst(), transport.Src(), transport.Dst()}
+ invertedFlow := StreamFlow{net.Dst(), net.Src(), transport.Dst(), transport.Src()}
factory.mConnections.Lock()
- connection, isPresent := factory.connections[invertedKey]
+ connection, isPresent := factory.connections[invertedFlow]
if isPresent {
- delete(factory.connections, invertedKey)
+ delete(factory.connections, invertedFlow)
} else {
- var initiator StreamKey
+ var connectionFlow StreamFlow
if net.Src() == factory.serverIp {
- initiator = invertedKey
+ connectionFlow = invertedFlow
} else {
- initiator = key
+ connectionFlow = flow
}
connection = &connectionHandlerImpl{
- storage: factory.storage,
- net: net,
- transport: transport,
- initiator: initiator,
- mComplete: sync.Mutex{},
- context: context.Background(),
- patterns: factory.patterns,
+ connectionFlow: connectionFlow,
+ mComplete: sync.Mutex{},
+ factory: factory,
}
- factory.connections[key] = connection
+ factory.connections[flow] = connection
}
factory.mConnections.Unlock()
- streamHandler := NewStreamHandler(connection, key, factory.takeScratch())
+ streamHandler := NewStreamHandler(connection, flow, factory.takeScanner())
return &streamHandler
}
-func (factory *BiDirectionalStreamFactory) UpdatePatternsDatabase(database hyperscan.StreamDatabase) {
- factory.mPatterns.Lock()
- factory.patterns = database
-
- for _, s := range factory.scratches {
- err := s.Realloc(database)
- if err != nil {
- fmt.Println("failed to realloc an existing scratch")
- }
- }
-
- factory.mPatterns.Unlock()
-}
-
func (ch *connectionHandlerImpl) Complete(handler *StreamHandler) {
+ ch.factory.releaseScanner(handler.scanner)
ch.mComplete.Lock()
if ch.otherStream == nil {
ch.otherStream = handler
@@ -112,7 +178,7 @@ func (ch *connectionHandlerImpl) Complete(handler *StreamHandler) {
}
var client, server *StreamHandler
- if handler.streamKey == ch.initiator {
+ if handler.streamFlow == ch.connectionFlow {
client = handler
server = ch.otherStream
} else {
@@ -120,81 +186,48 @@ func (ch *connectionHandlerImpl) Complete(handler *StreamHandler) {
server = handler
}
- ch.generateConnectionKey(startedAt)
-
- _, err := ch.storage.Insert("connections").Context(ch.context).One(OrderedDocument{
- {"_id", ch.connectionKey},
- {"ip_src", ch.initiator[0].String()},
- {"ip_dst", ch.initiator[1].String()},
- {"port_src", ch.initiator[2].String()},
- {"port_dst", ch.initiator[3].String()},
- {"started_at", startedAt},
- {"closed_at", closedAt},
- {"client_bytes", client.streamLength},
- {"server_bytes", server.streamLength},
- {"client_documents", len(client.documentsKeys)},
- {"server_documents", len(server.documentsKeys)},
- {"processed_at", time.Now()},
- })
+ connectionID := ch.Storage().NewCustomRowID(ch.connectionFlow.Hash(), startedAt)
+ connection := Connection{
+ ID: connectionID,
+ SourceIP: ch.connectionFlow[0].String(),
+ DestinationIP: ch.connectionFlow[1].String(),
+ SourcePort: binary.BigEndian.Uint16(ch.connectionFlow[2].Raw()),
+ DestinationPort: binary.BigEndian.Uint16(ch.connectionFlow[3].Raw()),
+ StartedAt: startedAt,
+ ClosedAt: closedAt,
+ ClientPackets: client.packetsCount,
+		ServerPackets:   server.packetsCount,
+ ClientBytes: client.streamLength,
+ ServerBytes: server.streamLength,
+ ClientDocuments: len(client.documentsIDs),
+ ServerDocuments: len(server.documentsIDs),
+ ProcessedAt: time.Now(),
+ }
+ _, err := ch.Storage().Insert(Connections).One(connection)
if err != nil {
- log.Println("error inserting document on collection connections with _id = ", ch.connectionKey)
+ log.WithError(err).WithField("connection", connection).Error("failed to insert a connection")
+ return
}
- streamsIds := append(client.documentsKeys, server.documentsKeys...)
- n, err := ch.storage.Update("connection_streams").
- Context(ch.context).
+ streamsIds := append(client.documentsIDs, server.documentsIDs...)
+ n, err := ch.Storage().Update(ConnectionStreams).
Filter(OrderedDocument{{"_id", UnorderedDocument{"$in": streamsIds}}}).
- Many(UnorderedDocument{"connection_id": ch.connectionKey})
+ Many(UnorderedDocument{"connection_id": connectionID})
if err != nil {
- log.Println("failed to update connection streams", err)
- }
- if int(n) != len(streamsIds) {
- log.Println("failed to update all connections streams")
+ log.WithError(err).WithField("connection", connection).Error("failed to update connection streams")
+ } else if int(n) != len(streamsIds) {
+		log.WithField("connection", connection).Error("failed to update all connections streams")
}
}
func (ch *connectionHandlerImpl) Storage() Storage {
- return ch.storage
-}
-
-func (ch *connectionHandlerImpl) Context() context.Context {
- return ch.context
-}
-
-func (ch *connectionHandlerImpl) Patterns() hyperscan.StreamDatabase {
- return ch.patterns
+ return ch.factory.storage
}
-func (ch *connectionHandlerImpl) generateConnectionKey(firstPacketSeen time.Time) {
- hash := make([]byte, 16)
- binary.BigEndian.PutUint64(hash, uint64(firstPacketSeen.UnixNano()))
- binary.BigEndian.PutUint64(hash[8:], ch.net.FastHash()^ch.transport.FastHash())
-
- ch.connectionKey = fmt.Sprintf("%x", hash)
-}
-
-func (factory *BiDirectionalStreamFactory) takeScratch() *hyperscan.Scratch {
- factory.mPatterns.Lock()
- defer factory.mPatterns.Unlock()
-
- if len(factory.scratches) == 0 {
- scratch, err := hyperscan.NewScratch(factory.patterns)
- if err != nil {
- fmt.Println("failed to alloc a new scratch")
- }
-
- return scratch
- }
-
- index := len(factory.scratches) - 1
- scratch := factory.scratches[index]
- factory.scratches = factory.scratches[:index]
-
- return scratch
+func (ch *connectionHandlerImpl) PatternsDatabase() hyperscan.StreamDatabase {
+ return ch.factory.rulesDatabase.database
}
-func (factory *BiDirectionalStreamFactory) releaseScratch(scratch *hyperscan.Scratch) {
- factory.mPatterns.Lock()
- factory.scratches = append(factory.scratches, scratch)
- factory.mPatterns.Unlock()
+func (sf StreamFlow) Hash() uint64 {
+ return sf[0].FastHash() ^ sf[1].FastHash() ^ sf[2].FastHash() ^ sf[3].FastHash()
}