diff options
-rw-r--r-- | connection_handler.go | 23 | ||||
-rw-r--r-- | connection_handler_test.go | 146 | ||||
-rw-r--r-- | rules_manager.go | 24 |
3 files changed, 176 insertions, 17 deletions
diff --git a/connection_handler.go b/connection_handler.go index de30634..dc23315 100644 --- a/connection_handler.go +++ b/connection_handler.go @@ -18,7 +18,7 @@ type BiDirectionalStreamFactory struct { serverIP gopacket.Endpoint connections map[StreamFlow]ConnectionHandler mConnections sync.Mutex - rulesManager *RulesManager + rulesManager RulesManager rulesDatabase RulesDatabase mRulesDatabase sync.Mutex scanners []Scanner @@ -46,7 +46,7 @@ type connectionHandlerImpl struct { } func NewBiDirectionalStreamFactory(storage Storage, serverIP gopacket.Endpoint, - rulesManager *RulesManager) *BiDirectionalStreamFactory { + rulesManager RulesManager) *BiDirectionalStreamFactory { factory := &BiDirectionalStreamFactory{ storage: storage, @@ -65,7 +65,7 @@ func NewBiDirectionalStreamFactory(storage Storage, serverIP gopacket.Endpoint, func (factory *BiDirectionalStreamFactory) updateRulesDatabaseService() { for { select { - case rulesDatabase, ok := <-factory.rulesManager.databaseUpdated: + case rulesDatabase, ok := <-factory.rulesManager.DatabaseUpdateChannel(): if !ok { return } @@ -122,6 +122,7 @@ func (factory *BiDirectionalStreamFactory) releaseScanner(scanner Scanner) { log.WithError(err).Error("failed to realloc an existing scanner") return } + scanner.version = factory.rulesDatabase.version } factory.scanners = append(factory.scanners, scanner) } @@ -213,13 +214,15 @@ func (ch *connectionHandlerImpl) Complete(handler *StreamHandler) { } streamsIDs := append(client.documentsIDs, server.documentsIDs...) - n, err := ch.Storage().Update(ConnectionStreams). - Filter(OrderedDocument{{"_id", UnorderedDocument{"$in": streamsIDs}}}). - Many(UnorderedDocument{"connection_id": connectionID}) - if err != nil { - log.WithError(err).WithField("connection", connection).Error("failed to update connection streams") - } else if int(n) != len(streamsIDs) { - log.WithError(err).WithField("connection", connection).Error("failed to update all connections streams") + if len(streamsIDs) > 0 { + n, err := ch.Storage().Update(ConnectionStreams). + Filter(OrderedDocument{{"_id", UnorderedDocument{"$in": streamsIDs}}}). + Many(UnorderedDocument{"connection_id": connectionID}) + if err != nil { + log.WithError(err).WithField("connection", connection).Error("failed to update connection streams") + } else if int(n) != len(streamsIDs) { + log.WithError(err).WithField("connection", connection).Error("failed to update all connections streams") + } } } diff --git a/connection_handler_test.go b/connection_handler_test.go new file mode 100644 index 0000000..51fa0bc --- /dev/null +++ b/connection_handler_test.go @@ -0,0 +1,146 @@ +package main + +import ( + "context" + "encoding/binary" + "github.com/flier/gohs/hyperscan" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "net" + "testing" + "time" +) + +func TestTakeReleaseScanners(t *testing.T) { + wrapper := NewTestStorageWrapper(t) + serverIP := layers.NewIPEndpoint(net.ParseIP(testDstIP)) + ruleManager := TestRuleManager{ + databaseUpdated: make(chan RulesDatabase), + } + + database, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/nope/", 0)) + require.NoError(t, err) + + factory := NewBiDirectionalStreamFactory(wrapper.Storage, serverIP, &ruleManager) + version := wrapper.Storage.NewRowID() + ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version} + time.Sleep(10 * time.Millisecond) + + n := 100 + for i := 0; i < n; i++ { + scanner := factory.takeScanner() + assert.Equal(t, scanner.version, version) + + if i%5 == 0 { + version = wrapper.Storage.NewRowID() + ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version} + time.Sleep(10 * time.Millisecond) + } + factory.releaseScanner(scanner) + } + assert.Len(t, factory.scanners, 1) + + scanners := make([]Scanner, n) + for i := 0; i < n; i++ { + scanners[i] = factory.takeScanner() + assert.Equal(t, scanners[i].version, version) + } + + version = wrapper.Storage.NewRowID() + ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version} + time.Sleep(10 * time.Millisecond) + + for i := 0; i < n; i++ { + factory.releaseScanner(scanners[i]) + } + assert.Len(t, factory.scanners, n) + + for i := 0; i < n; i++ { + scanners[i] = factory.takeScanner() + assert.Equal(t, scanners[i].version, version) + factory.releaseScanner(scanners[i]) + } + + wrapper.Destroy(t) +} + + +func TestConnectionFactory(t *testing.T) { + wrapper := NewTestStorageWrapper(t) + wrapper.AddCollection(Connections) + wrapper.AddCollection(ConnectionStreams) + + ruleManager := TestRuleManager{ + databaseUpdated: make(chan RulesDatabase), + } + + serverIP := layers.NewIPEndpoint(net.ParseIP(testDstIP)) + clientIP := layers.NewIPEndpoint(net.ParseIP(testSrcIP)) + serverPort := layers.NewTCPPortEndpoint(dstPort) + clientPort := layers.NewTCPPortEndpoint(srcPort) + serverClientNetFlow, err := gopacket.FlowFromEndpoints(serverIP, clientIP) + require.NoError(t, err) + serverClientTransportFlow, err := gopacket.FlowFromEndpoints(serverPort, clientPort) + require.NoError(t, err) + clientServerNetFlow, err := gopacket.FlowFromEndpoints(clientIP, serverIP) + require.NoError(t, err) + clientServerTransportFlow, err := gopacket.FlowFromEndpoints(clientPort, serverPort) + require.NoError(t, err) + + database, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/nope/", 0)) + require.NoError(t, err) + + factory := NewBiDirectionalStreamFactory(wrapper.Storage, serverIP, &ruleManager) + version := wrapper.Storage.NewRowID() + ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version} + time.Sleep(10 * time.Millisecond) + + serverStream := factory.New(serverClientNetFlow, serverClientTransportFlow) + connectionFlow := StreamFlow{clientIP, serverIP, clientPort, serverPort} + invertedConnectionFlow := StreamFlow{serverIP, clientIP, serverPort, clientPort} + connection, isPresent := factory.connections[invertedConnectionFlow] + require.True(t, isPresent) + assert.Equal(t, connectionFlow, connection.(*connectionHandlerImpl).connectionFlow) + + serverStream.ReassemblyComplete() + assert.Equal(t, invertedConnectionFlow, connection.(*connectionHandlerImpl).otherStream.streamFlow) + + clientStream := factory.New(clientServerNetFlow, clientServerTransportFlow) + assert.Len(t, factory.connections, 0) + clientStream.ReassemblyComplete() + + var result Connection + err = wrapper.Storage.Find(Connections).Context(wrapper.Context).First(&result) + require.NoError(t, err) + + assert.NotNil(t, result) + assert.Equal(t, wrapper.Storage.NewCustomRowID(connectionFlow.Hash(), result.StartedAt), result.ID) + assert.Equal(t, clientIP.String(), result.SourceIP) + assert.Equal(t, serverIP.String(), result.DestinationIP) + assert.Equal(t, binary.BigEndian.Uint16(clientPort.Raw()), result.SourcePort) + assert.Equal(t, binary.BigEndian.Uint16(serverPort.Raw()), result.DestinationPort) + + wrapper.Destroy(t) +} + + +type TestRuleManager struct { + databaseUpdated chan RulesDatabase +} + +func (rm TestRuleManager) LoadRules() error { + return nil +} + +func (rm TestRuleManager) AddRule(_ context.Context, _ Rule) (string, error) { + return "", nil +} + +func (rm TestRuleManager) FillWithMatchedRules(_ *Connection, _ map[uint][]PatternSlice, _ map[uint][]PatternSlice) { +} + +func (rm TestRuleManager) DatabaseUpdateChannel() chan RulesDatabase { + return rm.databaseUpdated +} diff --git a/rules_manager.go b/rules_manager.go index 482188e..0750d53 100644 --- a/rules_manager.go +++ b/rules_manager.go @@ -58,7 +58,14 @@ type RulesDatabase struct { version RowID } -type RulesManager struct { +type RulesManager interface { + LoadRules() error + AddRule(context context.Context, rule Rule) (string, error) + FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice, serverMatches map[uint][]PatternSlice) + DatabaseUpdateChannel() chan RulesDatabase +} + +type rulesManagerImpl struct { storage Storage rules map[string]Rule rulesByName map[string]Rule @@ -69,7 +76,7 @@ type RulesManager struct { } func NewRulesManager(storage Storage) RulesManager { - return RulesManager{ + return &rulesManagerImpl{ storage: storage, rules: make(map[string]Rule), patterns: make(map[string]Pattern), @@ -77,7 +84,7 @@ func NewRulesManager(storage Storage) RulesManager { } } -func (rm RulesManager) LoadRules() error { +func (rm rulesManagerImpl) LoadRules() error { var rules []Rule if err := rm.storage.Find(Rules).Sort("_id", true).All(&rules); err != nil { return err @@ -93,7 +100,7 @@ func (rm RulesManager) LoadRules() error { return rm.generateDatabase(rules[len(rules)-1].ID) } -func (rm RulesManager) AddRule(context context.Context, rule Rule) (string, error) { +func (rm rulesManagerImpl) AddRule(context context.Context, rule Rule) (string, error) { rm.mPatterns.Lock() rule.ID = rm.storage.NewCustomRowID(uint64(rm.ruleIndex), time.Now()) @@ -117,7 +124,7 @@ func (rm RulesManager) AddRule(context context.Context, rule Rule) (string, erro return rule.ID.Hex(), nil } -func (rm RulesManager) validateAndAddRuleLocal(rule *Rule) error { +func (rm rulesManagerImpl) validateAndAddRuleLocal(rule *Rule) error { if _, alreadyPresent := rm.rulesByName[rule.Name]; alreadyPresent { return errors.New("rule name must be unique") } @@ -147,7 +154,7 @@ func (rm RulesManager) validateAndAddRuleLocal(rule *Rule) error { return nil } -func (rm RulesManager) generateDatabase(version RowID) error { +func (rm rulesManagerImpl) generateDatabase(version RowID) error { patterns := make([]*hyperscan.Pattern, len(rm.patterns)) var i int for _, pattern := range rm.patterns { @@ -167,9 +174,12 @@ func (rm RulesManager) generateDatabase(version RowID) error { return nil } -func (rm RulesManager) FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice, +func (rm rulesManagerImpl) FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice, serverMatches map[uint][]PatternSlice) { +} +func (rm rulesManagerImpl) DatabaseUpdateChannel() chan RulesDatabase { + return rm.databaseUpdated } func (p Pattern) BuildPattern() error { |