aboutsummaryrefslogtreecommitdiff
path: root/connection_handler_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'connection_handler_test.go')
-rw-r--r--connection_handler_test.go110
1 files changed, 74 insertions, 36 deletions
diff --git a/connection_handler_test.go b/connection_handler_test.go
index 51fa0bc..85aee0c 100644
--- a/connection_handler_test.go
+++ b/connection_handler_test.go
@@ -6,8 +6,10 @@ import (
"github.com/flier/gohs/hyperscan"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
+ "github.com/google/gopacket/tcpassembly"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+ "math/rand"
"net"
"testing"
"time"
@@ -28,12 +30,12 @@ func TestTakeReleaseScanners(t *testing.T) {
ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version}
time.Sleep(10 * time.Millisecond)
- n := 100
+ n := 1000
for i := 0; i < n; i++ {
scanner := factory.takeScanner()
assert.Equal(t, scanner.version, version)
- if i%5 == 0 {
+ if i%50 == 0 {
version = wrapper.Storage.NewRowID()
ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version}
time.Sleep(10 * time.Millisecond)
@@ -47,26 +49,25 @@ func TestTakeReleaseScanners(t *testing.T) {
scanners[i] = factory.takeScanner()
assert.Equal(t, scanners[i].version, version)
}
-
- version = wrapper.Storage.NewRowID()
- ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version}
- time.Sleep(10 * time.Millisecond)
-
for i := 0; i < n; i++ {
factory.releaseScanner(scanners[i])
}
assert.Len(t, factory.scanners, n)
+ version = wrapper.Storage.NewRowID()
+ ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version}
+ time.Sleep(10 * time.Millisecond)
+
for i := 0; i < n; i++ {
scanners[i] = factory.takeScanner()
assert.Equal(t, scanners[i].version, version)
factory.releaseScanner(scanners[i])
}
+ close(ruleManager.DatabaseUpdateChannel())
wrapper.Destroy(t)
}
-
func TestConnectionFactory(t *testing.T) {
wrapper := NewTestStorageWrapper(t)
wrapper.AddCollection(Connections)
@@ -76,17 +77,12 @@ func TestConnectionFactory(t *testing.T) {
databaseUpdated: make(chan RulesDatabase),
}
- serverIP := layers.NewIPEndpoint(net.ParseIP(testDstIP))
clientIP := layers.NewIPEndpoint(net.ParseIP(testSrcIP))
+ serverIP := layers.NewIPEndpoint(net.ParseIP(testDstIP))
serverPort := layers.NewTCPPortEndpoint(dstPort)
- clientPort := layers.NewTCPPortEndpoint(srcPort)
- serverClientNetFlow, err := gopacket.FlowFromEndpoints(serverIP, clientIP)
- require.NoError(t, err)
- serverClientTransportFlow, err := gopacket.FlowFromEndpoints(serverPort, clientPort)
- require.NoError(t, err)
clientServerNetFlow, err := gopacket.FlowFromEndpoints(clientIP, serverIP)
require.NoError(t, err)
- clientServerTransportFlow, err := gopacket.FlowFromEndpoints(clientPort, serverPort)
+ serverClientNetFlow, err := gopacket.FlowFromEndpoints(serverIP, clientIP)
require.NoError(t, err)
database, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/nope/", 0))
@@ -97,35 +93,77 @@ func TestConnectionFactory(t *testing.T) {
ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version}
time.Sleep(10 * time.Millisecond)
- serverStream := factory.New(serverClientNetFlow, serverClientTransportFlow)
- connectionFlow := StreamFlow{clientIP, serverIP, clientPort, serverPort}
- invertedConnectionFlow := StreamFlow{serverIP, clientIP, serverPort, clientPort}
- connection, isPresent := factory.connections[invertedConnectionFlow]
- require.True(t, isPresent)
- assert.Equal(t, connectionFlow, connection.(*connectionHandlerImpl).connectionFlow)
+ testInteraction := func(netFlow gopacket.Flow, transportFlow gopacket.Flow, otherSeenChan chan time.Time,
+ completed chan bool) {
+
+ time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
+ stream := factory.New(netFlow, transportFlow)
+ seen := time.Now()
+ stream.Reassembled([]tcpassembly.Reassembly{{[]byte{}, 0, true, true, seen}})
+ stream.ReassemblyComplete()
+
+ var startedAt, closedAt time.Time
+ if netFlow == serverClientNetFlow {
+ otherSeenChan <- seen
+ return
+ } else {
+ otherSeen, ok := <-otherSeenChan
+ require.True(t, ok)
+
+ if seen.Before(otherSeen) {
+ startedAt = seen
+ closedAt = otherSeen
+ } else {
+ startedAt = otherSeen
+ closedAt = seen
+ }
+ }
+ close(otherSeenChan)
+
+ var result Connection
+ connectionFlow := StreamFlow{netFlow.Src(), netFlow.Dst(), transportFlow.Src(), transportFlow.Dst()}
+ connectionID := wrapper.Storage.NewCustomRowID(connectionFlow.Hash(), startedAt)
+ err = wrapper.Storage.Find(Connections).Context(wrapper.Context).
+ Filter(OrderedDocument{{"_id", connectionID}}).First(&result)
+ require.NoError(t, err)
+
+ assert.NotNil(t, result)
+ assert.Equal(t, wrapper.Storage.NewCustomRowID(connectionFlow.Hash(), result.StartedAt), result.ID)
+ assert.Equal(t, netFlow.Src().String(), result.SourceIP)
+ assert.Equal(t, netFlow.Dst().String(), result.DestinationIP)
+ assert.Equal(t, binary.BigEndian.Uint16(transportFlow.Src().Raw()), result.SourcePort)
+ assert.Equal(t, binary.BigEndian.Uint16(transportFlow.Dst().Raw()), result.DestinationPort)
+ assert.Equal(t, startedAt.Unix(), result.StartedAt.Unix())
+ assert.Equal(t, closedAt.Unix(), result.ClosedAt.Unix())
+
+ completed <- true
+ }
- serverStream.ReassemblyComplete()
- assert.Equal(t, invertedConnectionFlow, connection.(*connectionHandlerImpl).otherStream.streamFlow)
+ completed := make(chan bool)
+ n := 3000
- clientStream := factory.New(clientServerNetFlow, clientServerTransportFlow)
- assert.Len(t, factory.connections, 0)
- clientStream.ReassemblyComplete()
+ for port := 40000; port < 40000+n; port++ {
+ clientPort := layers.NewTCPPortEndpoint(layers.TCPPort(port))
+ clientServerTransportFlow, err := gopacket.FlowFromEndpoints(clientPort, serverPort)
+ require.NoError(t, err)
+ serverClientTransportFlow, err := gopacket.FlowFromEndpoints(serverPort, clientPort)
+ require.NoError(t, err)
- var result Connection
- err = wrapper.Storage.Find(Connections).Context(wrapper.Context).First(&result)
- require.NoError(t, err)
+ otherSeenChan := make(chan time.Time)
+ go testInteraction(clientServerNetFlow, clientServerTransportFlow, otherSeenChan, completed)
+ go testInteraction(serverClientNetFlow, serverClientTransportFlow, otherSeenChan, completed)
+ }
+
+ for i := 0; i < n; i++ {
+ <-completed
+ }
- assert.NotNil(t, result)
- assert.Equal(t, wrapper.Storage.NewCustomRowID(connectionFlow.Hash(), result.StartedAt), result.ID)
- assert.Equal(t, clientIP.String(), result.SourceIP)
- assert.Equal(t, serverIP.String(), result.DestinationIP)
- assert.Equal(t, binary.BigEndian.Uint16(clientPort.Raw()), result.SourcePort)
- assert.Equal(t, binary.BigEndian.Uint16(serverPort.Raw()), result.DestinationPort)
+ assert.Len(t, factory.connections, 0)
+ close(ruleManager.DatabaseUpdateChannel())
wrapper.Destroy(t)
}
-
type TestRuleManager struct {
databaseUpdated chan RulesDatabase
}