aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--connection_handler.go2
-rw-r--r--connection_handler_test.go110
-rw-r--r--connections.go2
-rw-r--r--storage.go4
-rw-r--r--stream_handler.go3
5 files changed, 74 insertions, 47 deletions
diff --git a/connection_handler.go b/connection_handler.go
index dc23315..b8bddc9 100644
--- a/connection_handler.go
+++ b/connection_handler.go
@@ -197,8 +197,6 @@ func (ch *connectionHandlerImpl) Complete(handler *StreamHandler) {
DestinationPort: binary.BigEndian.Uint16(ch.connectionFlow[3].Raw()),
StartedAt: startedAt,
ClosedAt: closedAt,
- ClientPackets: client.packetsCount,
- ServerPackets: client.packetsCount,
ClientBytes: client.streamLength,
ServerBytes: server.streamLength,
ClientDocuments: len(client.documentsIDs),
diff --git a/connection_handler_test.go b/connection_handler_test.go
index 51fa0bc..85aee0c 100644
--- a/connection_handler_test.go
+++ b/connection_handler_test.go
@@ -6,8 +6,10 @@ import (
"github.com/flier/gohs/hyperscan"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
+ "github.com/google/gopacket/tcpassembly"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+ "math/rand"
"net"
"testing"
"time"
@@ -28,12 +30,12 @@ func TestTakeReleaseScanners(t *testing.T) {
ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version}
time.Sleep(10 * time.Millisecond)
- n := 100
+ n := 1000
for i := 0; i < n; i++ {
scanner := factory.takeScanner()
assert.Equal(t, scanner.version, version)
- if i%5 == 0 {
+ if i%50 == 0 {
version = wrapper.Storage.NewRowID()
ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version}
time.Sleep(10 * time.Millisecond)
@@ -47,26 +49,25 @@ func TestTakeReleaseScanners(t *testing.T) {
scanners[i] = factory.takeScanner()
assert.Equal(t, scanners[i].version, version)
}
-
- version = wrapper.Storage.NewRowID()
- ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version}
- time.Sleep(10 * time.Millisecond)
-
for i := 0; i < n; i++ {
factory.releaseScanner(scanners[i])
}
assert.Len(t, factory.scanners, n)
+ version = wrapper.Storage.NewRowID()
+ ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version}
+ time.Sleep(10 * time.Millisecond)
+
for i := 0; i < n; i++ {
scanners[i] = factory.takeScanner()
assert.Equal(t, scanners[i].version, version)
factory.releaseScanner(scanners[i])
}
+ close(ruleManager.DatabaseUpdateChannel())
wrapper.Destroy(t)
}
-
func TestConnectionFactory(t *testing.T) {
wrapper := NewTestStorageWrapper(t)
wrapper.AddCollection(Connections)
@@ -76,17 +77,12 @@ func TestConnectionFactory(t *testing.T) {
databaseUpdated: make(chan RulesDatabase),
}
- serverIP := layers.NewIPEndpoint(net.ParseIP(testDstIP))
clientIP := layers.NewIPEndpoint(net.ParseIP(testSrcIP))
+ serverIP := layers.NewIPEndpoint(net.ParseIP(testDstIP))
serverPort := layers.NewTCPPortEndpoint(dstPort)
- clientPort := layers.NewTCPPortEndpoint(srcPort)
- serverClientNetFlow, err := gopacket.FlowFromEndpoints(serverIP, clientIP)
- require.NoError(t, err)
- serverClientTransportFlow, err := gopacket.FlowFromEndpoints(serverPort, clientPort)
- require.NoError(t, err)
clientServerNetFlow, err := gopacket.FlowFromEndpoints(clientIP, serverIP)
require.NoError(t, err)
- clientServerTransportFlow, err := gopacket.FlowFromEndpoints(clientPort, serverPort)
+ serverClientNetFlow, err := gopacket.FlowFromEndpoints(serverIP, clientIP)
require.NoError(t, err)
database, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/nope/", 0))
@@ -97,35 +93,77 @@ func TestConnectionFactory(t *testing.T) {
ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version}
time.Sleep(10 * time.Millisecond)
- serverStream := factory.New(serverClientNetFlow, serverClientTransportFlow)
- connectionFlow := StreamFlow{clientIP, serverIP, clientPort, serverPort}
- invertedConnectionFlow := StreamFlow{serverIP, clientIP, serverPort, clientPort}
- connection, isPresent := factory.connections[invertedConnectionFlow]
- require.True(t, isPresent)
- assert.Equal(t, connectionFlow, connection.(*connectionHandlerImpl).connectionFlow)
+ testInteraction := func(netFlow gopacket.Flow, transportFlow gopacket.Flow, otherSeenChan chan time.Time,
+ completed chan bool) {
+
+ time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
+ stream := factory.New(netFlow, transportFlow)
+ seen := time.Now()
+ stream.Reassembled([]tcpassembly.Reassembly{{[]byte{}, 0, true, true, seen}})
+ stream.ReassemblyComplete()
+
+ var startedAt, closedAt time.Time
+ if netFlow == serverClientNetFlow {
+ otherSeenChan <- seen
+ return
+ } else {
+ otherSeen, ok := <-otherSeenChan
+ require.True(t, ok)
+
+ if seen.Before(otherSeen) {
+ startedAt = seen
+ closedAt = otherSeen
+ } else {
+ startedAt = otherSeen
+ closedAt = seen
+ }
+ }
+ close(otherSeenChan)
+
+ var result Connection
+ connectionFlow := StreamFlow{netFlow.Src(), netFlow.Dst(), transportFlow.Src(), transportFlow.Dst()}
+ connectionID := wrapper.Storage.NewCustomRowID(connectionFlow.Hash(), startedAt)
+ err = wrapper.Storage.Find(Connections).Context(wrapper.Context).
+ Filter(OrderedDocument{{"_id", connectionID}}).First(&result)
+ require.NoError(t, err)
+
+ assert.NotNil(t, result)
+ assert.Equal(t, wrapper.Storage.NewCustomRowID(connectionFlow.Hash(), result.StartedAt), result.ID)
+ assert.Equal(t, netFlow.Src().String(), result.SourceIP)
+ assert.Equal(t, netFlow.Dst().String(), result.DestinationIP)
+ assert.Equal(t, binary.BigEndian.Uint16(transportFlow.Src().Raw()), result.SourcePort)
+ assert.Equal(t, binary.BigEndian.Uint16(transportFlow.Dst().Raw()), result.DestinationPort)
+ assert.Equal(t, startedAt.Unix(), result.StartedAt.Unix())
+ assert.Equal(t, closedAt.Unix(), result.ClosedAt.Unix())
+
+ completed <- true
+ }
- serverStream.ReassemblyComplete()
- assert.Equal(t, invertedConnectionFlow, connection.(*connectionHandlerImpl).otherStream.streamFlow)
+ completed := make(chan bool)
+ n := 3000
- clientStream := factory.New(clientServerNetFlow, clientServerTransportFlow)
- assert.Len(t, factory.connections, 0)
- clientStream.ReassemblyComplete()
+ for port := 40000; port < 40000+n; port++ {
+ clientPort := layers.NewTCPPortEndpoint(layers.TCPPort(port))
+ clientServerTransportFlow, err := gopacket.FlowFromEndpoints(clientPort, serverPort)
+ require.NoError(t, err)
+ serverClientTransportFlow, err := gopacket.FlowFromEndpoints(serverPort, clientPort)
+ require.NoError(t, err)
- var result Connection
- err = wrapper.Storage.Find(Connections).Context(wrapper.Context).First(&result)
- require.NoError(t, err)
+ otherSeenChan := make(chan time.Time)
+ go testInteraction(clientServerNetFlow, clientServerTransportFlow, otherSeenChan, completed)
+ go testInteraction(serverClientNetFlow, serverClientTransportFlow, otherSeenChan, completed)
+ }
+
+ for i := 0; i < n; i++ {
+ <-completed
+ }
- assert.NotNil(t, result)
- assert.Equal(t, wrapper.Storage.NewCustomRowID(connectionFlow.Hash(), result.StartedAt), result.ID)
- assert.Equal(t, clientIP.String(), result.SourceIP)
- assert.Equal(t, serverIP.String(), result.DestinationIP)
- assert.Equal(t, binary.BigEndian.Uint16(clientPort.Raw()), result.SourcePort)
- assert.Equal(t, binary.BigEndian.Uint16(serverPort.Raw()), result.DestinationPort)
+ assert.Len(t, factory.connections, 0)
+ close(ruleManager.DatabaseUpdateChannel())
wrapper.Destroy(t)
}
-
type TestRuleManager struct {
databaseUpdated chan RulesDatabase
}
diff --git a/connections.go b/connections.go
index 3c22a9a..380c8a1 100644
--- a/connections.go
+++ b/connections.go
@@ -10,8 +10,6 @@ type Connection struct {
DestinationPort uint16 `json:"port_dst" bson:"port_dst"`
StartedAt time.Time `json:"started_at" bson:"started_at"`
ClosedAt time.Time `json:"closed_at" bson:"closed_at"`
- ClientPackets int `json:"client_packets" bson:"client_packets"`
- ServerPackets int `json:"server_packets" bson:"server_packets"`
ClientBytes int `json:"client_bytes" bson:"client_bytes"`
ServerBytes int `json:"server_bytes" bson:"server_bytes"`
ClientDocuments int `json:"client_documents" bson:"client_documents"`
diff --git a/storage.go b/storage.go
index b88c5c8..7d98ba0 100644
--- a/storage.go
+++ b/storage.go
@@ -64,10 +64,6 @@ func NewMongoStorage(uri string, port int, database string) *MongoStorage {
}
func (storage *MongoStorage) Connect(ctx context.Context) error {
- if ctx == nil {
- ctx, _ = context.WithTimeout(context.Background(), defaultConnectionTimeout)
- }
-
return storage.client.Connect(ctx)
}
diff --git a/stream_handler.go b/stream_handler.go
index f2309ad..78326c6 100644
--- a/stream_handler.go
+++ b/stream_handler.go
@@ -28,7 +28,6 @@ type StreamHandler struct {
lastPacketSeen time.Time
documentsIDs []RowID
streamLength int
- packetsCount int
patternStream hyperscan.Stream
patternMatches map[uint][]PatternSlice
scanner Scanner
@@ -60,8 +59,6 @@ func NewStreamHandler(connection ConnectionHandler, streamFlow StreamFlow, scann
// Reassembled implements tcpassembly.Stream's Reassembled function.
func (sh *StreamHandler) Reassembled(reassembly []tcpassembly.Reassembly) {
for _, r := range reassembly {
- sh.packetsCount++
-
skip := r.Skip
if r.Start {
skip = 0