author     Emiliano Ciavatta   2020-04-09 10:37:48 +0000
committer  Emiliano Ciavatta   2020-04-09 10:37:48 +0000
commit     7ca2f30a0eb21e22071f4e6b04a5207fa273d283 (patch)
tree       63acb98147ffda7606bdf81abe2894e5f8363bd9
parent     0520dab47d61e2c4de246459bf4f5c72d69182d3 (diff)
Refactor connection_handler
-rw-r--r--  caronte.go                6
-rw-r--r--  connection_handler.go   261
-rw-r--r--  connections.go           20
-rw-r--r--  rules_manager.go         12
-rw-r--r--  stream_handler.go        51
-rw-r--r--  stream_handler_test.go   13
6 files changed, 216 insertions, 147 deletions
diff --git a/caronte.go b/caronte.go
index a32b619..a6fa584 100644
--- a/caronte.go
+++ b/caronte.go
@@ -4,7 +4,7 @@ import (
"flag"
"fmt"
"github.com/gin-gonic/gin"
- "log"
+ log "github.com/sirupsen/logrus"
)
func main() {
@@ -21,13 +21,13 @@ func main() {
storage := NewMongoStorage(*mongoHost, *mongoPort, *dbName)
err := storage.Connect(nil)
if err != nil {
- log.Panicln("failed to connect to MongoDB:", err)
+ log.WithError(err).Fatal("failed to connect to MongoDB")
}
router := gin.Default()
ApplicationRoutes(router)
err = router.Run(fmt.Sprintf("%s:%v", *bindAddress, *bindPort))
if err != nil {
- log.Panicln("failed to create the server:", err)
+ log.WithError(err).Fatal("failed to create the server")
}
}
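The logging switch above trades the standard library's log for logrus, which carries errors and context as structured fields instead of formatting them into the message. A minimal sketch of the pattern, assuming only the logrus API used in this commit:

package main

import (
	"errors"

	log "github.com/sirupsen/logrus"
)

func main() {
	err := errors.New("connection refused")
	// WithError attaches err as a structured "error" field on the entry;
	// Error logs and continues, while Fatal (used above) also exits.
	log.WithError(err).Error("failed to connect to MongoDB")
}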
diff --git a/connection_handler.go b/connection_handler.go
index 2e2fa84..4fc2a95 100644
--- a/connection_handler.go
+++ b/connection_handler.go
@@ -1,95 +1,161 @@
package main
import (
- "context"
"encoding/binary"
- "fmt"
"github.com/flier/gohs/hyperscan"
"github.com/google/gopacket"
"github.com/google/gopacket/tcpassembly"
- "log"
+ log "github.com/sirupsen/logrus"
"sync"
"time"
)
+const initialConnectionsCapacity = 1024
+const initialScannersCapacity = 1024
+
type BiDirectionalStreamFactory struct {
- storage Storage
- serverIp gopacket.Endpoint
- connections map[StreamKey]ConnectionHandler
- mConnections sync.Mutex
- patterns hyperscan.StreamDatabase
- mPatterns sync.Mutex
- scratches []*hyperscan.Scratch
+ storage Storage
+ serverIp gopacket.Endpoint
+ connections map[StreamFlow]ConnectionHandler
+ mConnections sync.Mutex
+ rulesManager *RulesManager
+ rulesDatabase RulesDatabase
+ mRulesDatabase sync.Mutex
+ scanners []Scanner
}
-type StreamKey [4]gopacket.Endpoint
+type StreamFlow [4]gopacket.Endpoint
+
+type Scanner struct {
+ scratch *hyperscan.Scratch
+ version RowID
+}
type ConnectionHandler interface {
Complete(handler *StreamHandler)
Storage() Storage
- Context() context.Context
- Patterns() hyperscan.StreamDatabase
+ PatternsDatabase() hyperscan.StreamDatabase
}
type connectionHandlerImpl struct {
- storage Storage
- net, transport gopacket.Flow
- initiator StreamKey
- connectionKey string
+ factory *BiDirectionalStreamFactory
+ connectionFlow StreamFlow
mComplete sync.Mutex
otherStream *StreamHandler
- context context.Context
- patterns hyperscan.StreamDatabase
+}
+
+func NewBiDirectionalStreamFactory(storage Storage, serverIP gopacket.Endpoint,
+ rulesManager *RulesManager) *BiDirectionalStreamFactory {
+
+ factory := &BiDirectionalStreamFactory{
+ storage: storage,
+ serverIp: serverIP,
+ connections: make(map[StreamFlow]ConnectionHandler, initialConnectionsCapacity),
+ mConnections: sync.Mutex{},
+ rulesManager: rulesManager,
+ mRulesDatabase: sync.Mutex{},
+ scanners: make([]Scanner, 0, initialScannersCapacity),
+ }
+
+ go factory.updateRulesDatabaseService()
+ return factory
+}
+
+func (factory *BiDirectionalStreamFactory) updateRulesDatabaseService() {
+ for {
+ select {
+ case rulesDatabase, ok := <-factory.rulesManager.databaseUpdated:
+ if !ok {
+ return
+ }
+ factory.mRulesDatabase.Lock()
+ scanners := factory.scanners
+ factory.scanners = factory.scanners[:0]
+
+ for _, s := range scanners {
+ err := s.scratch.Realloc(rulesDatabase.database)
+ if err != nil {
+ log.WithError(err).Error("failed to realloc an existing scanner")
+ } else {
+ s.version = rulesDatabase.version
+ factory.scanners = append(factory.scanners, s)
+ }
+ }
+
+ factory.rulesDatabase = rulesDatabase
+ factory.mRulesDatabase.Unlock()
+ }
+ }
+}
+
+func (factory *BiDirectionalStreamFactory) takeScanner() Scanner {
+ factory.mRulesDatabase.Lock()
+ defer factory.mRulesDatabase.Unlock()
+
+ if len(factory.scanners) == 0 {
+ scratch, err := hyperscan.NewScratch(factory.rulesDatabase.database)
+ if err != nil {
+ log.WithError(err).Fatal("failed to alloc a new scratch")
+ }
+
+ return Scanner{
+ scratch: scratch,
+ version: factory.rulesDatabase.version,
+ }
+ }
+
+ index := len(factory.scanners) - 1
+ scanner := factory.scanners[index]
+ factory.scanners = factory.scanners[:index]
+
+ return scanner
+}
+
+func (factory *BiDirectionalStreamFactory) releaseScanner(scanner Scanner) {
+ factory.mRulesDatabase.Lock()
+ defer factory.mRulesDatabase.Unlock()
+
+ if scanner.version != factory.rulesDatabase.version {
+ err := scanner.scratch.Realloc(factory.rulesDatabase.database)
+ if err != nil {
+ log.WithError(err).Error("failed to realloc an existing scanner")
+ return
+ }
+ }
+ factory.scanners = append(factory.scanners, scanner)
}
func (factory *BiDirectionalStreamFactory) New(net, transport gopacket.Flow) tcpassembly.Stream {
- key := StreamKey{net.Src(), net.Dst(), transport.Src(), transport.Dst()}
- invertedKey := StreamKey{net.Dst(), net.Src(), transport.Dst(), transport.Src()}
+ flow := StreamFlow{net.Src(), net.Dst(), transport.Src(), transport.Dst()}
+ invertedFlow := StreamFlow{net.Dst(), net.Src(), transport.Dst(), transport.Src()}
factory.mConnections.Lock()
- connection, isPresent := factory.connections[invertedKey]
+ connection, isPresent := factory.connections[invertedFlow]
if isPresent {
- delete(factory.connections, invertedKey)
+ delete(factory.connections, invertedFlow)
} else {
- var initiator StreamKey
+ var connectionFlow StreamFlow
if net.Src() == factory.serverIp {
- initiator = invertedKey
+ connectionFlow = invertedFlow
} else {
- initiator = key
+ connectionFlow = flow
}
connection = &connectionHandlerImpl{
- storage: factory.storage,
- net: net,
- transport: transport,
- initiator: initiator,
- mComplete: sync.Mutex{},
- context: context.Background(),
- patterns: factory.patterns,
+ connectionFlow: connectionFlow,
+ mComplete: sync.Mutex{},
+ factory: factory,
}
- factory.connections[key] = connection
+ factory.connections[flow] = connection
}
factory.mConnections.Unlock()
- streamHandler := NewStreamHandler(connection, key, factory.takeScratch())
+ streamHandler := NewStreamHandler(connection, flow, factory.takeScanner())
return &streamHandler
}
-func (factory *BiDirectionalStreamFactory) UpdatePatternsDatabase(database hyperscan.StreamDatabase) {
- factory.mPatterns.Lock()
- factory.patterns = database
-
- for _, s := range factory.scratches {
- err := s.Realloc(database)
- if err != nil {
- fmt.Println("failed to realloc an existing scratch")
- }
- }
-
- factory.mPatterns.Unlock()
-}
-
func (ch *connectionHandlerImpl) Complete(handler *StreamHandler) {
+ ch.factory.releaseScanner(handler.scanner)
ch.mComplete.Lock()
if ch.otherStream == nil {
ch.otherStream = handler
@@ -112,7 +178,7 @@ func (ch *connectionHandlerImpl) Complete(handler *StreamHandler) {
}
var client, server *StreamHandler
- if handler.streamKey == ch.initiator {
+ if handler.streamFlow == ch.connectionFlow {
client = handler
server = ch.otherStream
} else {
@@ -120,81 +186,48 @@ func (ch *connectionHandlerImpl) Complete(handler *StreamHandler) {
server = handler
}
- ch.generateConnectionKey(startedAt)
-
- _, err := ch.storage.Insert("connections").Context(ch.context).One(OrderedDocument{
- {"_id", ch.connectionKey},
- {"ip_src", ch.initiator[0].String()},
- {"ip_dst", ch.initiator[1].String()},
- {"port_src", ch.initiator[2].String()},
- {"port_dst", ch.initiator[3].String()},
- {"started_at", startedAt},
- {"closed_at", closedAt},
- {"client_bytes", client.streamLength},
- {"server_bytes", server.streamLength},
- {"client_documents", len(client.documentsKeys)},
- {"server_documents", len(server.documentsKeys)},
- {"processed_at", time.Now()},
- })
+ connectionID := ch.Storage().NewCustomRowID(ch.connectionFlow.Hash(), startedAt)
+ connection := Connection{
+ ID: connectionID,
+ SourceIP: ch.connectionFlow[0].String(),
+ DestinationIP: ch.connectionFlow[1].String(),
+ SourcePort: binary.BigEndian.Uint16(ch.connectionFlow[2].Raw()),
+ DestinationPort: binary.BigEndian.Uint16(ch.connectionFlow[3].Raw()),
+ StartedAt: startedAt,
+ ClosedAt: closedAt,
+ ClientPackets: client.packetsCount,
+ ServerPackets: server.packetsCount,
+ ClientBytes: client.streamLength,
+ ServerBytes: server.streamLength,
+ ClientDocuments: len(client.documentsIDs),
+ ServerDocuments: len(server.documentsIDs),
+ ProcessedAt: time.Now(),
+ }
+ _, err := ch.Storage().Insert(Connections).One(connection)
if err != nil {
- log.Println("error inserting document on collection connections with _id = ", ch.connectionKey)
+ log.WithError(err).WithField("connection", connection).Error("failed to insert a connection")
+ return
}
- streamsIds := append(client.documentsKeys, server.documentsKeys...)
- n, err := ch.storage.Update("connection_streams").
- Context(ch.context).
+ streamsIds := append(client.documentsIDs, server.documentsIDs...)
+ n, err := ch.Storage().Update(ConnectionStreams).
Filter(OrderedDocument{{"_id", UnorderedDocument{"$in": streamsIds}}}).
- Many(UnorderedDocument{"connection_id": ch.connectionKey})
+ Many(UnorderedDocument{"connection_id": connectionID})
if err != nil {
- log.Println("failed to update connection streams", err)
- }
- if int(n) != len(streamsIds) {
- log.Println("failed to update all connections streams")
+ log.WithError(err).WithField("connection", connection).Error("failed to update connection streams")
+ } else if int(n) != len(streamsIds) {
+ log.WithField("connection", connection).Error("failed to update all connection streams")
}
}
func (ch *connectionHandlerImpl) Storage() Storage {
- return ch.storage
-}
-
-func (ch *connectionHandlerImpl) Context() context.Context {
- return ch.context
-}
-
-func (ch *connectionHandlerImpl) Patterns() hyperscan.StreamDatabase {
- return ch.patterns
+ return ch.factory.storage
}
-func (ch *connectionHandlerImpl) generateConnectionKey(firstPacketSeen time.Time) {
- hash := make([]byte, 16)
- binary.BigEndian.PutUint64(hash, uint64(firstPacketSeen.UnixNano()))
- binary.BigEndian.PutUint64(hash[8:], ch.net.FastHash()^ch.transport.FastHash())
-
- ch.connectionKey = fmt.Sprintf("%x", hash)
-}
-
-func (factory *BiDirectionalStreamFactory) takeScratch() *hyperscan.Scratch {
- factory.mPatterns.Lock()
- defer factory.mPatterns.Unlock()
-
- if len(factory.scratches) == 0 {
- scratch, err := hyperscan.NewScratch(factory.patterns)
- if err != nil {
- fmt.Println("failed to alloc a new scratch")
- }
-
- return scratch
- }
-
- index := len(factory.scratches) - 1
- scratch := factory.scratches[index]
- factory.scratches = factory.scratches[:index]
-
- return scratch
+func (ch *connectionHandlerImpl) PatternsDatabase() hyperscan.StreamDatabase {
+ return ch.factory.rulesDatabase.database
}
-func (factory *BiDirectionalStreamFactory) releaseScratch(scratch *hyperscan.Scratch) {
- factory.mPatterns.Lock()
- factory.scratches = append(factory.scratches, scratch)
- factory.mPatterns.Unlock()
+func (sf StreamFlow) Hash() uint64 {
+ return sf[0].FastHash() ^ sf[1].FastHash() ^ sf[2].FastHash() ^ sf[3].FastHash()
}
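The heart of the refactor is the versioned scanner pool: takeScanner pops an idle scratch or lazily allocates one, and releaseScanner reallocs any scratch whose version lags the current rules database before returning it. A self-contained sketch of the same take/release pattern, with the hyperscan scratch stubbed out (resource and pool are hypothetical names):

package main

import (
	"fmt"
	"sync"
)

// resource stands in for a hyperscan scratch tied to a database version.
type resource struct{ version int }

type pool struct {
	mu      sync.Mutex
	version int        // current database version
	free    []resource // idle resources, LIFO
}

// take pops an idle resource or allocates a fresh one at the current version.
func (p *pool) take() resource {
	p.mu.Lock()
	defer p.mu.Unlock()
	if n := len(p.free); n > 0 {
		r := p.free[n-1]
		p.free = p.free[:n-1]
		return r
	}
	return resource{version: p.version}
}

// release refreshes a stale resource before putting it back, mirroring
// releaseScanner's Realloc-on-version-mismatch.
func (p *pool) release(r resource) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if r.version != p.version {
		r.version = p.version // in the real code: scratch.Realloc(database)
	}
	p.free = append(p.free, r)
}

func main() {
	p := &pool{version: 1}
	r := p.take()
	p.version = 2 // rules database updated while r was in use
	p.release(r)
	fmt.Println(p.free[0].version) // 2: refreshed on release
}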
diff --git a/connections.go b/connections.go
new file mode 100644
index 0000000..e3adaf2
--- /dev/null
+++ b/connections.go
@@ -0,0 +1,20 @@
+package main
+
+import "time"
+
+type Connection struct {
+ ID RowID `json:"id" bson:"_id"`
+ SourceIP string `json:"ip_src" bson:"ip_src"`
+ DestinationIP string `json:"ip_dst" bson:"ip_dst"`
+ SourcePort uint16 `json:"port_src" bson:"port_src"`
+ DestinationPort uint16 `json:"port_dst" bson:"port_dst"`
+ StartedAt time.Time `json:"started_at" bson:"started_at"`
+ ClosedAt time.Time `json:"closed_at" bson:"closed_at"`
+ ClientPackets int `json:"client_packets" bson:"client_packets"`
+ ServerPackets int `json:"server_packets" bson:"server_packets"`
+ ClientBytes int `json:"client_bytes" bson:"client_bytes"`
+ ServerBytes int `json:"server_bytes" bson:"server_bytes"`
+ ClientDocuments int `json:"client_documents" bson:"client_documents"`
+ ServerDocuments int `json:"server_documents" bson:"server_documents"`
+ ProcessedAt time.Time `json:"processed_at" bson:"processed_at"`
+}
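The new Connection model serves double duty: the bson tags shape the MongoDB document, the json tags the API payload. A small sketch of the JSON side, with RowID stubbed as a plain string for illustration:

package main

import (
	"encoding/json"
	"fmt"
	"time"
)

// conn mirrors a few Connection fields; RowID is stubbed as a string here.
type conn struct {
	ID         string    `json:"id"`
	SourceIP   string    `json:"ip_src"`
	SourcePort uint16    `json:"port_src"`
	StartedAt  time.Time `json:"started_at"`
}

func main() {
	c := conn{ID: "7ca2f30a", SourceIP: "10.0.0.1", SourcePort: 443, StartedAt: time.Unix(0, 0).UTC()}
	out, _ := json.Marshal(c)
	fmt.Println(string(out))
	// {"id":"7ca2f30a","ip_src":"10.0.0.1","port_src":443,"started_at":"1970-01-01T00:00:00Z"}
}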
diff --git a/rules_manager.go b/rules_manager.go
index e5a8d38..69439e0 100644
--- a/rules_manager.go
+++ b/rules_manager.go
@@ -52,6 +52,11 @@ type Rule struct {
Version int64 `json:"version" bson:"version"`
}
+type RulesDatabase struct {
+ database hyperscan.StreamDatabase
+ version RowID
+}
+
type RulesManager struct {
storage Storage
rules map[string]Rule
@@ -59,7 +64,7 @@ type RulesManager struct {
ruleIndex int
patterns map[string]Pattern
mPatterns sync.Mutex
- databaseUpdated chan interface{}
+ databaseUpdated chan RulesDatabase
}
func NewRulesManager(storage Storage) RulesManager {
@@ -153,7 +158,10 @@ func (rm RulesManager) generateDatabase(version RowID) error {
return err
}
- rm.databaseUpdated <- database
+ rm.databaseUpdated <- RulesDatabase{
+ database: database,
+ version: version,
+ }
return nil
}
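Bundling the compiled database with its version into RulesDatabase means every message on databaseUpdated is an atomic (database, version) snapshot, which is what lets releaseScanner detect stale scratches. A minimal sketch of the producer/consumer handshake, with the hyperscan types stubbed out:

package main

import (
	"fmt"
	"sync"
)

type rulesDB struct {
	patterns []string // stands in for hyperscan.StreamDatabase
	version  int      // stands in for RowID
}

func main() {
	updates := make(chan rulesDB)
	var wg sync.WaitGroup
	wg.Add(1)

	// Consumer: mirrors updateRulesDatabaseService, exiting when the
	// channel is closed.
	go func() {
		defer wg.Done()
		for db := range updates {
			fmt.Printf("swapped to version %d (%d patterns)\n", db.version, len(db.patterns))
		}
	}()

	// Producer: mirrors generateDatabase, publishing the database and its
	// version together so readers never see a torn update.
	updates <- rulesDB{patterns: []string{"flag\\{"}, version: 1}
	updates <- rulesDB{patterns: []string{"flag\\{", "secret"}, version: 2}
	close(updates)
	wg.Wait()
}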
diff --git a/stream_handler.go b/stream_handler.go
index 3fafa21..2d80f60 100644
--- a/stream_handler.go
+++ b/stream_handler.go
@@ -18,7 +18,7 @@ const InitialPatternSliceSize = 8
// method:
type StreamHandler struct {
connection ConnectionHandler
- streamKey StreamKey
+ streamFlow StreamFlow
buffer *bytes.Buffer
indexes []int
timestamps []time.Time
@@ -26,28 +26,31 @@ type StreamHandler struct {
currentIndex int
firstPacketSeen time.Time
lastPacketSeen time.Time
- documentsKeys []RowID
+ documentsIDs []RowID
streamLength int
+ packetsCount int
patternStream hyperscan.Stream
patternMatches map[uint][]PatternSlice
+ scanner Scanner
}
// NewReaderStream returns a new StreamHandler object.
-func NewStreamHandler(connection ConnectionHandler, key StreamKey, scratch *hyperscan.Scratch) StreamHandler {
+func NewStreamHandler(connection ConnectionHandler, streamFlow StreamFlow, scanner Scanner) StreamHandler {
handler := StreamHandler{
connection: connection,
- streamKey: key,
+ streamFlow: streamFlow,
buffer: new(bytes.Buffer),
indexes: make([]int, 0, InitialBlockCount),
timestamps: make([]time.Time, 0, InitialBlockCount),
lossBlocks: make([]bool, 0, InitialBlockCount),
- documentsKeys: make([]RowID, 0, 1), // most of the time the stream fit in one document
+ documentsIDs: make([]RowID, 0, 1), // most of the time the stream fits in one document
+ patternMatches: make(map[uint][]PatternSlice, 10), // TODO: replace with the exact value
+ scanner: scanner,
}
- stream, err := connection.Patterns().Open(0, scratch, handler.onMatch, nil)
+ stream, err := connection.PatternsDatabase().Open(0, scanner.scratch, handler.onMatch, nil)
if err != nil {
- log.WithField("streamKey", key).WithError(err).Error("failed to create a stream")
+ log.WithField("streamFlow", streamFlow).WithError(err).Error("failed to create a stream")
}
handler.patternStream = stream
@@ -57,6 +60,8 @@ func NewStreamHandler(connection ConnectionHandler, key StreamKey, scratch *hype
// Reassembled implements tcpassembly.Stream's Reassembled function.
func (sh *StreamHandler) Reassembled(reassembly []tcpassembly.Reassembly) {
for _, r := range reassembly {
+ sh.packetsCount++
+
skip := r.Skip
if r.Start {
skip = 0
@@ -78,7 +83,7 @@ func (sh *StreamHandler) Reassembled(reassembly []tcpassembly.Reassembly) {
n, err := sh.buffer.Write(r.Bytes[skip:])
if err != nil {
log.WithError(err).Error("failed to copy bytes from a Reassemble")
- return
+ continue
}
sh.indexes = append(sh.indexes, sh.currentIndex)
sh.timestamps = append(sh.timestamps, r.Seen)
@@ -86,18 +91,22 @@ func (sh *StreamHandler) Reassembled(reassembly []tcpassembly.Reassembly) {
sh.currentIndex += n
sh.streamLength += n
- err = sh.patternStream.Scan(r.Bytes)
- if err != nil {
- log.WithError(err).Error("failed to scan packet buffer")
+ if sh.patternStream != nil {
+ err = sh.patternStream.Scan(r.Bytes)
+ if err != nil {
+ log.WithError(err).Error("failed to scan packet buffer")
+ }
}
}
}
// ReassemblyComplete implements tcpassembly.Stream's ReassemblyComplete function.
func (sh *StreamHandler) ReassemblyComplete() {
- err := sh.patternStream.Close()
- if err != nil {
- log.WithError(err).Error("failed to close pattern stream")
+ if sh.patternStream != nil {
+ err := sh.patternStream.Close()
+ if err != nil {
+ log.WithError(err).Error("failed to close pattern stream")
+ }
}
if sh.currentIndex > 0 {
@@ -140,16 +149,14 @@ func (sh *StreamHandler) onMatch(id uint, from uint64, to uint64, flags uint, co
}
func (sh *StreamHandler) storageCurrentDocument() {
- payload := (sh.streamKey[0].FastHash()^sh.streamKey[1].FastHash()^sh.streamKey[2].FastHash()^
- sh.streamKey[3].FastHash())&uint64(0xffffffffffffff00) | uint64(len(sh.documentsKeys)) // LOL
- streamKey := sh.connection.Storage().NewCustomRowID(payload, sh.firstPacketSeen)
+ payload := sh.streamFlow.Hash()&uint64(0xffffffffffffff00) | uint64(len(sh.documentsIDs)) // low byte stores the document index
+ streamID := sh.connection.Storage().NewCustomRowID(payload, sh.firstPacketSeen)
_, err := sh.connection.Storage().Insert(ConnectionStreams).
- Context(sh.connection.Context()).
One(ConnectionStream{
- ID: streamKey,
+ ID: streamID,
ConnectionID: ZeroRowID,
- DocumentIndex: len(sh.documentsKeys),
+ DocumentIndex: len(sh.documentsIDs),
Payload: sh.buffer.Bytes(),
BlocksIndexes: sh.indexes,
BlocksTimestamps: sh.timestamps,
@@ -159,7 +166,7 @@ func (sh *StreamHandler) storageCurrentDocument() {
if err != nil {
log.WithError(err).Error("failed to insert connection stream")
+ } else {
+ sh.documentsIDs = append(sh.documentsIDs, streamID)
}
-
- sh.documentsKeys = append(sh.documentsKeys, streamKey)
}
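storageCurrentDocument now derives the stream ID from StreamFlow.Hash(): the top 56 bits carry the flow hash and the low byte carries the per-stream document index. A standalone sketch of that packing:

package main

import "fmt"

// streamIDPayload mirrors the packing in storageCurrentDocument: keep the
// top 56 bits of the flow hash and store the document index in the low byte.
func streamIDPayload(flowHash uint64, docIndex int) uint64 {
	return flowHash&0xffffffffffffff00 | uint64(docIndex)
}

func main() {
	h := uint64(0xdeadbeefcafebabe)
	for i := 0; i < 3; i++ {
		fmt.Printf("doc %d -> %016x\n", i, streamIDPayload(h, i))
	}
	// The low byte enumerates documents of the same stream, so IDs stay
	// unique as long as a stream spans fewer than 256 documents.
}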
diff --git a/stream_handler_test.go b/stream_handler_test.go
index ece3190..962048f 100644
--- a/stream_handler_test.go
+++ b/stream_handler_test.go
@@ -39,7 +39,7 @@ func TestReassemblingEmptyStream(t *testing.T) {
assert.Zero(t, streamHandler.currentIndex)
assert.Zero(t, streamHandler.firstPacketSeen)
assert.Zero(t, streamHandler.lastPacketSeen)
- assert.Len(t, streamHandler.documentsKeys, 0)
+ assert.Len(t, streamHandler.documentsIDs, 0)
assert.Zero(t, streamHandler.streamLength)
assert.Len(t, streamHandler.patternMatches, 0)
@@ -125,7 +125,7 @@ func TestReassemblingSingleDocument(t *testing.T) {
assert.Equal(t, len(data), streamHandler.currentIndex)
assert.Equal(t, firstTime, streamHandler.firstPacketSeen)
assert.Equal(t, lastTime, streamHandler.lastPacketSeen)
- assert.Len(t, streamHandler.documentsKeys, 1)
+ assert.Len(t, streamHandler.documentsIDs, 1)
assert.Equal(t, len(data), streamHandler.streamLength)
assert.Len(t, streamHandler.patternMatches, 0)
@@ -210,7 +210,7 @@ func TestReassemblingMultipleDocuments(t *testing.T) {
assert.Equal(t, MaxDocumentSize, streamHandler.currentIndex)
assert.Equal(t, firstTime, streamHandler.firstPacketSeen)
assert.Equal(t, lastTime, streamHandler.lastPacketSeen)
- assert.Len(t, streamHandler.documentsKeys, 2)
+ assert.Len(t, streamHandler.documentsIDs, 2)
assert.Equal(t, len(data), streamHandler.streamLength)
assert.Len(t, streamHandler.patternMatches, 0)
@@ -287,7 +287,7 @@ func TestReassemblingPatternMatching(t *testing.T) {
assert.Equal(t, len(payload), streamHandler.currentIndex)
assert.Equal(t, seen, streamHandler.firstPacketSeen)
assert.Equal(t, seen, streamHandler.lastPacketSeen)
- assert.Len(t, streamHandler.documentsKeys, 1)
+ assert.Len(t, streamHandler.documentsIDs, 1)
assert.Equal(t, len(payload), streamHandler.streamLength)
assert.Equal(t, true, completed, "completed")
@@ -310,7 +310,8 @@ func createTestStreamHandler(wrapper *TestStorageWrapper, patterns hyperscan.Str
srcPort := layers.NewTCPPortEndpoint(srcPort)
dstPort := layers.NewTCPPortEndpoint(dstPort)
- return NewStreamHandler(testConnectionHandler, StreamKey{srcIp, dstIp, srcPort, dstPort}, scratch)
+ scanner := Scanner{scratch: scratch, version: ZeroRowID}
+ return NewStreamHandler(testConnectionHandler, StreamFlow{srcIp, dstIp, srcPort, dstPort}, scanner)
}
type testConnectionHandler struct {
@@ -327,7 +328,7 @@ func (tch *testConnectionHandler) Context() context.Context {
return tch.wrapper.Context
}
-func (tch *testConnectionHandler) Patterns() hyperscan.StreamDatabase {
+func (tch *testConnectionHandler) PatternsDatabase() hyperscan.StreamDatabase {
return tch.patterns
}
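For the overall picture, the factory slots into gopacket's TCP reassembly as a tcpassembly.StreamFactory. A hedged sketch of the wiring using this commit's constructors (buildAssembler is a hypothetical helper; storage and rules manager setup are elided and the server IP is a placeholder):

package main

import (
	"net"

	"github.com/google/gopacket/layers"
	"github.com/google/gopacket/tcpassembly"
)

// buildAssembler shows how the pieces are meant to fit together; it assumes
// the Storage and *RulesManager values are already constructed.
func buildAssembler(storage Storage, rulesManager *RulesManager) *tcpassembly.Assembler {
	// To4() keeps the endpoint's raw bytes 4 wide, matching IPv4 packet flows.
	serverIP := layers.NewIPEndpoint(net.ParseIP("10.10.10.10").To4())
	factory := NewBiDirectionalStreamFactory(storage, serverIP, rulesManager)
	pool := tcpassembly.NewStreamPool(factory)
	return tcpassembly.NewAssembler(pool)
}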