aboutsummaryrefslogtreecommitdiff
path: root/connection_handler_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'connection_handler_test.go')
-rw-r--r--connection_handler_test.go146
1 files changed, 146 insertions, 0 deletions
diff --git a/connection_handler_test.go b/connection_handler_test.go
new file mode 100644
index 0000000..51fa0bc
--- /dev/null
+++ b/connection_handler_test.go
@@ -0,0 +1,146 @@
+package main
+
+import (
+ "context"
+ "encoding/binary"
+ "github.com/flier/gohs/hyperscan"
+ "github.com/google/gopacket"
+ "github.com/google/gopacket/layers"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "net"
+ "testing"
+ "time"
+)
+
+func TestTakeReleaseScanners(t *testing.T) {
+ wrapper := NewTestStorageWrapper(t)
+ serverIP := layers.NewIPEndpoint(net.ParseIP(testDstIP))
+ ruleManager := TestRuleManager{
+ databaseUpdated: make(chan RulesDatabase),
+ }
+
+ database, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/nope/", 0))
+ require.NoError(t, err)
+
+ factory := NewBiDirectionalStreamFactory(wrapper.Storage, serverIP, &ruleManager)
+ version := wrapper.Storage.NewRowID()
+ ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version}
+ time.Sleep(10 * time.Millisecond)
+
+ n := 100
+ for i := 0; i < n; i++ {
+ scanner := factory.takeScanner()
+ assert.Equal(t, scanner.version, version)
+
+ if i%5 == 0 {
+ version = wrapper.Storage.NewRowID()
+ ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version}
+ time.Sleep(10 * time.Millisecond)
+ }
+ factory.releaseScanner(scanner)
+ }
+ assert.Len(t, factory.scanners, 1)
+
+ scanners := make([]Scanner, n)
+ for i := 0; i < n; i++ {
+ scanners[i] = factory.takeScanner()
+ assert.Equal(t, scanners[i].version, version)
+ }
+
+ version = wrapper.Storage.NewRowID()
+ ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version}
+ time.Sleep(10 * time.Millisecond)
+
+ for i := 0; i < n; i++ {
+ factory.releaseScanner(scanners[i])
+ }
+ assert.Len(t, factory.scanners, n)
+
+ for i := 0; i < n; i++ {
+ scanners[i] = factory.takeScanner()
+ assert.Equal(t, scanners[i].version, version)
+ factory.releaseScanner(scanners[i])
+ }
+
+ wrapper.Destroy(t)
+}
+
+
+func TestConnectionFactory(t *testing.T) {
+ wrapper := NewTestStorageWrapper(t)
+ wrapper.AddCollection(Connections)
+ wrapper.AddCollection(ConnectionStreams)
+
+ ruleManager := TestRuleManager{
+ databaseUpdated: make(chan RulesDatabase),
+ }
+
+ serverIP := layers.NewIPEndpoint(net.ParseIP(testDstIP))
+ clientIP := layers.NewIPEndpoint(net.ParseIP(testSrcIP))
+ serverPort := layers.NewTCPPortEndpoint(dstPort)
+ clientPort := layers.NewTCPPortEndpoint(srcPort)
+ serverClientNetFlow, err := gopacket.FlowFromEndpoints(serverIP, clientIP)
+ require.NoError(t, err)
+ serverClientTransportFlow, err := gopacket.FlowFromEndpoints(serverPort, clientPort)
+ require.NoError(t, err)
+ clientServerNetFlow, err := gopacket.FlowFromEndpoints(clientIP, serverIP)
+ require.NoError(t, err)
+ clientServerTransportFlow, err := gopacket.FlowFromEndpoints(clientPort, serverPort)
+ require.NoError(t, err)
+
+ database, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/nope/", 0))
+ require.NoError(t, err)
+
+ factory := NewBiDirectionalStreamFactory(wrapper.Storage, serverIP, &ruleManager)
+ version := wrapper.Storage.NewRowID()
+ ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version}
+ time.Sleep(10 * time.Millisecond)
+
+ serverStream := factory.New(serverClientNetFlow, serverClientTransportFlow)
+ connectionFlow := StreamFlow{clientIP, serverIP, clientPort, serverPort}
+ invertedConnectionFlow := StreamFlow{serverIP, clientIP, serverPort, clientPort}
+ connection, isPresent := factory.connections[invertedConnectionFlow]
+ require.True(t, isPresent)
+ assert.Equal(t, connectionFlow, connection.(*connectionHandlerImpl).connectionFlow)
+
+ serverStream.ReassemblyComplete()
+ assert.Equal(t, invertedConnectionFlow, connection.(*connectionHandlerImpl).otherStream.streamFlow)
+
+ clientStream := factory.New(clientServerNetFlow, clientServerTransportFlow)
+ assert.Len(t, factory.connections, 0)
+ clientStream.ReassemblyComplete()
+
+ var result Connection
+ err = wrapper.Storage.Find(Connections).Context(wrapper.Context).First(&result)
+ require.NoError(t, err)
+
+ assert.NotNil(t, result)
+ assert.Equal(t, wrapper.Storage.NewCustomRowID(connectionFlow.Hash(), result.StartedAt), result.ID)
+ assert.Equal(t, clientIP.String(), result.SourceIP)
+ assert.Equal(t, serverIP.String(), result.DestinationIP)
+ assert.Equal(t, binary.BigEndian.Uint16(clientPort.Raw()), result.SourcePort)
+ assert.Equal(t, binary.BigEndian.Uint16(serverPort.Raw()), result.DestinationPort)
+
+ wrapper.Destroy(t)
+}
+
+
+type TestRuleManager struct {
+ databaseUpdated chan RulesDatabase
+}
+
+func (rm TestRuleManager) LoadRules() error {
+ return nil
+}
+
+func (rm TestRuleManager) AddRule(_ context.Context, _ Rule) (string, error) {
+ return "", nil
+}
+
+func (rm TestRuleManager) FillWithMatchedRules(_ *Connection, _ map[uint][]PatternSlice, _ map[uint][]PatternSlice) {
+}
+
+func (rm TestRuleManager) DatabaseUpdateChannel() chan RulesDatabase {
+ return rm.databaseUpdated
+}