diff options
Diffstat (limited to 'connection_handler_test.go')
-rw-r--r-- | connection_handler_test.go | 110 |
1 files changed, 74 insertions, 36 deletions
diff --git a/connection_handler_test.go b/connection_handler_test.go index 51fa0bc..85aee0c 100644 --- a/connection_handler_test.go +++ b/connection_handler_test.go @@ -6,8 +6,10 @@ import ( "github.com/flier/gohs/hyperscan" "github.com/google/gopacket" "github.com/google/gopacket/layers" + "github.com/google/gopacket/tcpassembly" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "math/rand" "net" "testing" "time" @@ -28,12 +30,12 @@ func TestTakeReleaseScanners(t *testing.T) { ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version} time.Sleep(10 * time.Millisecond) - n := 100 + n := 1000 for i := 0; i < n; i++ { scanner := factory.takeScanner() assert.Equal(t, scanner.version, version) - if i%5 == 0 { + if i%50 == 0 { version = wrapper.Storage.NewRowID() ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version} time.Sleep(10 * time.Millisecond) @@ -47,26 +49,25 @@ func TestTakeReleaseScanners(t *testing.T) { scanners[i] = factory.takeScanner() assert.Equal(t, scanners[i].version, version) } - - version = wrapper.Storage.NewRowID() - ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version} - time.Sleep(10 * time.Millisecond) - for i := 0; i < n; i++ { factory.releaseScanner(scanners[i]) } assert.Len(t, factory.scanners, n) + version = wrapper.Storage.NewRowID() + ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version} + time.Sleep(10 * time.Millisecond) + for i := 0; i < n; i++ { scanners[i] = factory.takeScanner() assert.Equal(t, scanners[i].version, version) factory.releaseScanner(scanners[i]) } + close(ruleManager.DatabaseUpdateChannel()) wrapper.Destroy(t) } - func TestConnectionFactory(t *testing.T) { wrapper := NewTestStorageWrapper(t) wrapper.AddCollection(Connections) @@ -76,17 +77,12 @@ func TestConnectionFactory(t *testing.T) { databaseUpdated: make(chan RulesDatabase), } - serverIP := layers.NewIPEndpoint(net.ParseIP(testDstIP)) clientIP := layers.NewIPEndpoint(net.ParseIP(testSrcIP)) + serverIP := layers.NewIPEndpoint(net.ParseIP(testDstIP)) serverPort := layers.NewTCPPortEndpoint(dstPort) - clientPort := layers.NewTCPPortEndpoint(srcPort) - serverClientNetFlow, err := gopacket.FlowFromEndpoints(serverIP, clientIP) - require.NoError(t, err) - serverClientTransportFlow, err := gopacket.FlowFromEndpoints(serverPort, clientPort) - require.NoError(t, err) clientServerNetFlow, err := gopacket.FlowFromEndpoints(clientIP, serverIP) require.NoError(t, err) - clientServerTransportFlow, err := gopacket.FlowFromEndpoints(clientPort, serverPort) + serverClientNetFlow, err := gopacket.FlowFromEndpoints(serverIP, clientIP) require.NoError(t, err) database, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/nope/", 0)) @@ -97,35 +93,77 @@ func TestConnectionFactory(t *testing.T) { ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version} time.Sleep(10 * time.Millisecond) - serverStream := factory.New(serverClientNetFlow, serverClientTransportFlow) - connectionFlow := StreamFlow{clientIP, serverIP, clientPort, serverPort} - invertedConnectionFlow := StreamFlow{serverIP, clientIP, serverPort, clientPort} - connection, isPresent := factory.connections[invertedConnectionFlow] - require.True(t, isPresent) - assert.Equal(t, connectionFlow, connection.(*connectionHandlerImpl).connectionFlow) + testInteraction := func(netFlow gopacket.Flow, transportFlow gopacket.Flow, otherSeenChan chan time.Time, + completed chan bool) { + + time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond) + stream := factory.New(netFlow, transportFlow) + seen := time.Now() + stream.Reassembled([]tcpassembly.Reassembly{{[]byte{}, 0, true, true, seen}}) + stream.ReassemblyComplete() + + var startedAt, closedAt time.Time + if netFlow == serverClientNetFlow { + otherSeenChan <- seen + return + } else { + otherSeen, ok := <-otherSeenChan + require.True(t, ok) + + if seen.Before(otherSeen) { + startedAt = seen + closedAt = otherSeen + } else { + startedAt = otherSeen + closedAt = seen + } + } + close(otherSeenChan) + + var result Connection + connectionFlow := StreamFlow{netFlow.Src(), netFlow.Dst(), transportFlow.Src(), transportFlow.Dst()} + connectionID := wrapper.Storage.NewCustomRowID(connectionFlow.Hash(), startedAt) + err = wrapper.Storage.Find(Connections).Context(wrapper.Context). + Filter(OrderedDocument{{"_id", connectionID}}).First(&result) + require.NoError(t, err) + + assert.NotNil(t, result) + assert.Equal(t, wrapper.Storage.NewCustomRowID(connectionFlow.Hash(), result.StartedAt), result.ID) + assert.Equal(t, netFlow.Src().String(), result.SourceIP) + assert.Equal(t, netFlow.Dst().String(), result.DestinationIP) + assert.Equal(t, binary.BigEndian.Uint16(transportFlow.Src().Raw()), result.SourcePort) + assert.Equal(t, binary.BigEndian.Uint16(transportFlow.Dst().Raw()), result.DestinationPort) + assert.Equal(t, startedAt.Unix(), result.StartedAt.Unix()) + assert.Equal(t, closedAt.Unix(), result.ClosedAt.Unix()) + + completed <- true + } - serverStream.ReassemblyComplete() - assert.Equal(t, invertedConnectionFlow, connection.(*connectionHandlerImpl).otherStream.streamFlow) + completed := make(chan bool) + n := 3000 - clientStream := factory.New(clientServerNetFlow, clientServerTransportFlow) - assert.Len(t, factory.connections, 0) - clientStream.ReassemblyComplete() + for port := 40000; port < 40000+n; port++ { + clientPort := layers.NewTCPPortEndpoint(layers.TCPPort(port)) + clientServerTransportFlow, err := gopacket.FlowFromEndpoints(clientPort, serverPort) + require.NoError(t, err) + serverClientTransportFlow, err := gopacket.FlowFromEndpoints(serverPort, clientPort) + require.NoError(t, err) - var result Connection - err = wrapper.Storage.Find(Connections).Context(wrapper.Context).First(&result) - require.NoError(t, err) + otherSeenChan := make(chan time.Time) + go testInteraction(clientServerNetFlow, clientServerTransportFlow, otherSeenChan, completed) + go testInteraction(serverClientNetFlow, serverClientTransportFlow, otherSeenChan, completed) + } + + for i := 0; i < n; i++ { + <-completed + } - assert.NotNil(t, result) - assert.Equal(t, wrapper.Storage.NewCustomRowID(connectionFlow.Hash(), result.StartedAt), result.ID) - assert.Equal(t, clientIP.String(), result.SourceIP) - assert.Equal(t, serverIP.String(), result.DestinationIP) - assert.Equal(t, binary.BigEndian.Uint16(clientPort.Raw()), result.SourcePort) - assert.Equal(t, binary.BigEndian.Uint16(serverPort.Raw()), result.DestinationPort) + assert.Len(t, factory.connections, 0) + close(ruleManager.DatabaseUpdateChannel()) wrapper.Destroy(t) } - type TestRuleManager struct { databaseUpdated chan RulesDatabase } |