aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--connection_handler.go23
-rw-r--r--connection_handler_test.go146
-rw-r--r--rules_manager.go24
3 files changed, 176 insertions, 17 deletions
diff --git a/connection_handler.go b/connection_handler.go
index de30634..dc23315 100644
--- a/connection_handler.go
+++ b/connection_handler.go
@@ -18,7 +18,7 @@ type BiDirectionalStreamFactory struct {
serverIP gopacket.Endpoint
connections map[StreamFlow]ConnectionHandler
mConnections sync.Mutex
- rulesManager *RulesManager
+ rulesManager RulesManager
rulesDatabase RulesDatabase
mRulesDatabase sync.Mutex
scanners []Scanner
@@ -46,7 +46,7 @@ type connectionHandlerImpl struct {
}
func NewBiDirectionalStreamFactory(storage Storage, serverIP gopacket.Endpoint,
- rulesManager *RulesManager) *BiDirectionalStreamFactory {
+ rulesManager RulesManager) *BiDirectionalStreamFactory {
factory := &BiDirectionalStreamFactory{
storage: storage,
@@ -65,7 +65,7 @@ func NewBiDirectionalStreamFactory(storage Storage, serverIP gopacket.Endpoint,
func (factory *BiDirectionalStreamFactory) updateRulesDatabaseService() {
for {
select {
- case rulesDatabase, ok := <-factory.rulesManager.databaseUpdated:
+ case rulesDatabase, ok := <-factory.rulesManager.DatabaseUpdateChannel():
if !ok {
return
}
@@ -122,6 +122,7 @@ func (factory *BiDirectionalStreamFactory) releaseScanner(scanner Scanner) {
log.WithError(err).Error("failed to realloc an existing scanner")
return
}
+ scanner.version = factory.rulesDatabase.version
}
factory.scanners = append(factory.scanners, scanner)
}
@@ -213,13 +214,15 @@ func (ch *connectionHandlerImpl) Complete(handler *StreamHandler) {
}
streamsIDs := append(client.documentsIDs, server.documentsIDs...)
- n, err := ch.Storage().Update(ConnectionStreams).
- Filter(OrderedDocument{{"_id", UnorderedDocument{"$in": streamsIDs}}}).
- Many(UnorderedDocument{"connection_id": connectionID})
- if err != nil {
- log.WithError(err).WithField("connection", connection).Error("failed to update connection streams")
- } else if int(n) != len(streamsIDs) {
- log.WithError(err).WithField("connection", connection).Error("failed to update all connections streams")
+ if len(streamsIDs) > 0 {
+ n, err := ch.Storage().Update(ConnectionStreams).
+ Filter(OrderedDocument{{"_id", UnorderedDocument{"$in": streamsIDs}}}).
+ Many(UnorderedDocument{"connection_id": connectionID})
+ if err != nil {
+ log.WithError(err).WithField("connection", connection).Error("failed to update connection streams")
+ } else if int(n) != len(streamsIDs) {
+ log.WithError(err).WithField("connection", connection).Error("failed to update all connections streams")
+ }
}
}
diff --git a/connection_handler_test.go b/connection_handler_test.go
new file mode 100644
index 0000000..51fa0bc
--- /dev/null
+++ b/connection_handler_test.go
@@ -0,0 +1,146 @@
+package main
+
+import (
+ "context"
+ "encoding/binary"
+ "github.com/flier/gohs/hyperscan"
+ "github.com/google/gopacket"
+ "github.com/google/gopacket/layers"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "net"
+ "testing"
+ "time"
+)
+
+func TestTakeReleaseScanners(t *testing.T) {
+ wrapper := NewTestStorageWrapper(t)
+ serverIP := layers.NewIPEndpoint(net.ParseIP(testDstIP))
+ ruleManager := TestRuleManager{
+ databaseUpdated: make(chan RulesDatabase),
+ }
+
+ database, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/nope/", 0))
+ require.NoError(t, err)
+
+ factory := NewBiDirectionalStreamFactory(wrapper.Storage, serverIP, &ruleManager)
+ version := wrapper.Storage.NewRowID()
+ ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version}
+ time.Sleep(10 * time.Millisecond)
+
+ n := 100
+ for i := 0; i < n; i++ {
+ scanner := factory.takeScanner()
+ assert.Equal(t, scanner.version, version)
+
+ if i%5 == 0 {
+ version = wrapper.Storage.NewRowID()
+ ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version}
+ time.Sleep(10 * time.Millisecond)
+ }
+ factory.releaseScanner(scanner)
+ }
+ assert.Len(t, factory.scanners, 1)
+
+ scanners := make([]Scanner, n)
+ for i := 0; i < n; i++ {
+ scanners[i] = factory.takeScanner()
+ assert.Equal(t, scanners[i].version, version)
+ }
+
+ version = wrapper.Storage.NewRowID()
+ ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version}
+ time.Sleep(10 * time.Millisecond)
+
+ for i := 0; i < n; i++ {
+ factory.releaseScanner(scanners[i])
+ }
+ assert.Len(t, factory.scanners, n)
+
+ for i := 0; i < n; i++ {
+ scanners[i] = factory.takeScanner()
+ assert.Equal(t, scanners[i].version, version)
+ factory.releaseScanner(scanners[i])
+ }
+
+ wrapper.Destroy(t)
+}
+
+
+func TestConnectionFactory(t *testing.T) {
+ wrapper := NewTestStorageWrapper(t)
+ wrapper.AddCollection(Connections)
+ wrapper.AddCollection(ConnectionStreams)
+
+ ruleManager := TestRuleManager{
+ databaseUpdated: make(chan RulesDatabase),
+ }
+
+ serverIP := layers.NewIPEndpoint(net.ParseIP(testDstIP))
+ clientIP := layers.NewIPEndpoint(net.ParseIP(testSrcIP))
+ serverPort := layers.NewTCPPortEndpoint(dstPort)
+ clientPort := layers.NewTCPPortEndpoint(srcPort)
+ serverClientNetFlow, err := gopacket.FlowFromEndpoints(serverIP, clientIP)
+ require.NoError(t, err)
+ serverClientTransportFlow, err := gopacket.FlowFromEndpoints(serverPort, clientPort)
+ require.NoError(t, err)
+ clientServerNetFlow, err := gopacket.FlowFromEndpoints(clientIP, serverIP)
+ require.NoError(t, err)
+ clientServerTransportFlow, err := gopacket.FlowFromEndpoints(clientPort, serverPort)
+ require.NoError(t, err)
+
+ database, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/nope/", 0))
+ require.NoError(t, err)
+
+ factory := NewBiDirectionalStreamFactory(wrapper.Storage, serverIP, &ruleManager)
+ version := wrapper.Storage.NewRowID()
+ ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version}
+ time.Sleep(10 * time.Millisecond)
+
+ serverStream := factory.New(serverClientNetFlow, serverClientTransportFlow)
+ connectionFlow := StreamFlow{clientIP, serverIP, clientPort, serverPort}
+ invertedConnectionFlow := StreamFlow{serverIP, clientIP, serverPort, clientPort}
+ connection, isPresent := factory.connections[invertedConnectionFlow]
+ require.True(t, isPresent)
+ assert.Equal(t, connectionFlow, connection.(*connectionHandlerImpl).connectionFlow)
+
+ serverStream.ReassemblyComplete()
+ assert.Equal(t, invertedConnectionFlow, connection.(*connectionHandlerImpl).otherStream.streamFlow)
+
+ clientStream := factory.New(clientServerNetFlow, clientServerTransportFlow)
+ assert.Len(t, factory.connections, 0)
+ clientStream.ReassemblyComplete()
+
+ var result Connection
+ err = wrapper.Storage.Find(Connections).Context(wrapper.Context).First(&result)
+ require.NoError(t, err)
+
+ assert.NotNil(t, result)
+ assert.Equal(t, wrapper.Storage.NewCustomRowID(connectionFlow.Hash(), result.StartedAt), result.ID)
+ assert.Equal(t, clientIP.String(), result.SourceIP)
+ assert.Equal(t, serverIP.String(), result.DestinationIP)
+ assert.Equal(t, binary.BigEndian.Uint16(clientPort.Raw()), result.SourcePort)
+ assert.Equal(t, binary.BigEndian.Uint16(serverPort.Raw()), result.DestinationPort)
+
+ wrapper.Destroy(t)
+}
+
+
+type TestRuleManager struct {
+ databaseUpdated chan RulesDatabase
+}
+
+func (rm TestRuleManager) LoadRules() error {
+ return nil
+}
+
+func (rm TestRuleManager) AddRule(_ context.Context, _ Rule) (string, error) {
+ return "", nil
+}
+
+func (rm TestRuleManager) FillWithMatchedRules(_ *Connection, _ map[uint][]PatternSlice, _ map[uint][]PatternSlice) {
+}
+
+func (rm TestRuleManager) DatabaseUpdateChannel() chan RulesDatabase {
+ return rm.databaseUpdated
+}
diff --git a/rules_manager.go b/rules_manager.go
index 482188e..0750d53 100644
--- a/rules_manager.go
+++ b/rules_manager.go
@@ -58,7 +58,14 @@ type RulesDatabase struct {
version RowID
}
-type RulesManager struct {
+type RulesManager interface {
+ LoadRules() error
+ AddRule(context context.Context, rule Rule) (string, error)
+ FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice, serverMatches map[uint][]PatternSlice)
+ DatabaseUpdateChannel() chan RulesDatabase
+}
+
+type rulesManagerImpl struct {
storage Storage
rules map[string]Rule
rulesByName map[string]Rule
@@ -69,7 +76,7 @@ type RulesManager struct {
}
func NewRulesManager(storage Storage) RulesManager {
- return RulesManager{
+ return &rulesManagerImpl{
storage: storage,
rules: make(map[string]Rule),
patterns: make(map[string]Pattern),
@@ -77,7 +84,7 @@ func NewRulesManager(storage Storage) RulesManager {
}
}
-func (rm RulesManager) LoadRules() error {
+func (rm rulesManagerImpl) LoadRules() error {
var rules []Rule
if err := rm.storage.Find(Rules).Sort("_id", true).All(&rules); err != nil {
return err
@@ -93,7 +100,7 @@ func (rm RulesManager) LoadRules() error {
return rm.generateDatabase(rules[len(rules)-1].ID)
}
-func (rm RulesManager) AddRule(context context.Context, rule Rule) (string, error) {
+func (rm rulesManagerImpl) AddRule(context context.Context, rule Rule) (string, error) {
rm.mPatterns.Lock()
rule.ID = rm.storage.NewCustomRowID(uint64(rm.ruleIndex), time.Now())
@@ -117,7 +124,7 @@ func (rm RulesManager) AddRule(context context.Context, rule Rule) (string, erro
return rule.ID.Hex(), nil
}
-func (rm RulesManager) validateAndAddRuleLocal(rule *Rule) error {
+func (rm rulesManagerImpl) validateAndAddRuleLocal(rule *Rule) error {
if _, alreadyPresent := rm.rulesByName[rule.Name]; alreadyPresent {
return errors.New("rule name must be unique")
}
@@ -147,7 +154,7 @@ func (rm RulesManager) validateAndAddRuleLocal(rule *Rule) error {
return nil
}
-func (rm RulesManager) generateDatabase(version RowID) error {
+func (rm rulesManagerImpl) generateDatabase(version RowID) error {
patterns := make([]*hyperscan.Pattern, len(rm.patterns))
var i int
for _, pattern := range rm.patterns {
@@ -167,9 +174,12 @@ func (rm RulesManager) generateDatabase(version RowID) error {
return nil
}
-func (rm RulesManager) FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice,
+func (rm rulesManagerImpl) FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice,
serverMatches map[uint][]PatternSlice) {
+}
+func (rm rulesManagerImpl) DatabaseUpdateChannel() chan RulesDatabase {
+ return rm.databaseUpdated
}
func (p Pattern) BuildPattern() error {