diff options
-rw-r--r-- | connection_handler.go | 2 | ||||
-rw-r--r-- | connection_handler_test.go | 110 | ||||
-rw-r--r-- | connections.go | 2 | ||||
-rw-r--r-- | storage.go | 4 | ||||
-rw-r--r-- | stream_handler.go | 3 |
5 files changed, 74 insertions, 47 deletions
diff --git a/connection_handler.go b/connection_handler.go index dc23315..b8bddc9 100644 --- a/connection_handler.go +++ b/connection_handler.go @@ -197,8 +197,6 @@ func (ch *connectionHandlerImpl) Complete(handler *StreamHandler) { DestinationPort: binary.BigEndian.Uint16(ch.connectionFlow[3].Raw()), StartedAt: startedAt, ClosedAt: closedAt, - ClientPackets: client.packetsCount, - ServerPackets: client.packetsCount, ClientBytes: client.streamLength, ServerBytes: server.streamLength, ClientDocuments: len(client.documentsIDs), diff --git a/connection_handler_test.go b/connection_handler_test.go index 51fa0bc..85aee0c 100644 --- a/connection_handler_test.go +++ b/connection_handler_test.go @@ -6,8 +6,10 @@ import ( "github.com/flier/gohs/hyperscan" "github.com/google/gopacket" "github.com/google/gopacket/layers" + "github.com/google/gopacket/tcpassembly" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "math/rand" "net" "testing" "time" @@ -28,12 +30,12 @@ func TestTakeReleaseScanners(t *testing.T) { ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version} time.Sleep(10 * time.Millisecond) - n := 100 + n := 1000 for i := 0; i < n; i++ { scanner := factory.takeScanner() assert.Equal(t, scanner.version, version) - if i%5 == 0 { + if i%50 == 0 { version = wrapper.Storage.NewRowID() ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version} time.Sleep(10 * time.Millisecond) @@ -47,26 +49,25 @@ func TestTakeReleaseScanners(t *testing.T) { scanners[i] = factory.takeScanner() assert.Equal(t, scanners[i].version, version) } - - version = wrapper.Storage.NewRowID() - ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version} - time.Sleep(10 * time.Millisecond) - for i := 0; i < n; i++ { factory.releaseScanner(scanners[i]) } assert.Len(t, factory.scanners, n) + version = wrapper.Storage.NewRowID() + ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version} + time.Sleep(10 * time.Millisecond) + for i := 0; i < n; i++ { scanners[i] = factory.takeScanner() assert.Equal(t, scanners[i].version, version) factory.releaseScanner(scanners[i]) } + close(ruleManager.DatabaseUpdateChannel()) wrapper.Destroy(t) } - func TestConnectionFactory(t *testing.T) { wrapper := NewTestStorageWrapper(t) wrapper.AddCollection(Connections) @@ -76,17 +77,12 @@ func TestConnectionFactory(t *testing.T) { databaseUpdated: make(chan RulesDatabase), } - serverIP := layers.NewIPEndpoint(net.ParseIP(testDstIP)) clientIP := layers.NewIPEndpoint(net.ParseIP(testSrcIP)) + serverIP := layers.NewIPEndpoint(net.ParseIP(testDstIP)) serverPort := layers.NewTCPPortEndpoint(dstPort) - clientPort := layers.NewTCPPortEndpoint(srcPort) - serverClientNetFlow, err := gopacket.FlowFromEndpoints(serverIP, clientIP) - require.NoError(t, err) - serverClientTransportFlow, err := gopacket.FlowFromEndpoints(serverPort, clientPort) - require.NoError(t, err) clientServerNetFlow, err := gopacket.FlowFromEndpoints(clientIP, serverIP) require.NoError(t, err) - clientServerTransportFlow, err := gopacket.FlowFromEndpoints(clientPort, serverPort) + serverClientNetFlow, err := gopacket.FlowFromEndpoints(serverIP, clientIP) require.NoError(t, err) database, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/nope/", 0)) @@ -97,35 +93,77 @@ func TestConnectionFactory(t *testing.T) { ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version} time.Sleep(10 * time.Millisecond) - serverStream := factory.New(serverClientNetFlow, serverClientTransportFlow) - connectionFlow := StreamFlow{clientIP, serverIP, clientPort, serverPort} - invertedConnectionFlow := StreamFlow{serverIP, clientIP, serverPort, clientPort} - connection, isPresent := factory.connections[invertedConnectionFlow] - require.True(t, isPresent) - assert.Equal(t, connectionFlow, connection.(*connectionHandlerImpl).connectionFlow) + testInteraction := func(netFlow gopacket.Flow, transportFlow gopacket.Flow, otherSeenChan chan time.Time, + completed chan bool) { + + time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond) + stream := factory.New(netFlow, transportFlow) + seen := time.Now() + stream.Reassembled([]tcpassembly.Reassembly{{[]byte{}, 0, true, true, seen}}) + stream.ReassemblyComplete() + + var startedAt, closedAt time.Time + if netFlow == serverClientNetFlow { + otherSeenChan <- seen + return + } else { + otherSeen, ok := <-otherSeenChan + require.True(t, ok) + + if seen.Before(otherSeen) { + startedAt = seen + closedAt = otherSeen + } else { + startedAt = otherSeen + closedAt = seen + } + } + close(otherSeenChan) + + var result Connection + connectionFlow := StreamFlow{netFlow.Src(), netFlow.Dst(), transportFlow.Src(), transportFlow.Dst()} + connectionID := wrapper.Storage.NewCustomRowID(connectionFlow.Hash(), startedAt) + err = wrapper.Storage.Find(Connections).Context(wrapper.Context). + Filter(OrderedDocument{{"_id", connectionID}}).First(&result) + require.NoError(t, err) + + assert.NotNil(t, result) + assert.Equal(t, wrapper.Storage.NewCustomRowID(connectionFlow.Hash(), result.StartedAt), result.ID) + assert.Equal(t, netFlow.Src().String(), result.SourceIP) + assert.Equal(t, netFlow.Dst().String(), result.DestinationIP) + assert.Equal(t, binary.BigEndian.Uint16(transportFlow.Src().Raw()), result.SourcePort) + assert.Equal(t, binary.BigEndian.Uint16(transportFlow.Dst().Raw()), result.DestinationPort) + assert.Equal(t, startedAt.Unix(), result.StartedAt.Unix()) + assert.Equal(t, closedAt.Unix(), result.ClosedAt.Unix()) + + completed <- true + } - serverStream.ReassemblyComplete() - assert.Equal(t, invertedConnectionFlow, connection.(*connectionHandlerImpl).otherStream.streamFlow) + completed := make(chan bool) + n := 3000 - clientStream := factory.New(clientServerNetFlow, clientServerTransportFlow) - assert.Len(t, factory.connections, 0) - clientStream.ReassemblyComplete() + for port := 40000; port < 40000+n; port++ { + clientPort := layers.NewTCPPortEndpoint(layers.TCPPort(port)) + clientServerTransportFlow, err := gopacket.FlowFromEndpoints(clientPort, serverPort) + require.NoError(t, err) + serverClientTransportFlow, err := gopacket.FlowFromEndpoints(serverPort, clientPort) + require.NoError(t, err) - var result Connection - err = wrapper.Storage.Find(Connections).Context(wrapper.Context).First(&result) - require.NoError(t, err) + otherSeenChan := make(chan time.Time) + go testInteraction(clientServerNetFlow, clientServerTransportFlow, otherSeenChan, completed) + go testInteraction(serverClientNetFlow, serverClientTransportFlow, otherSeenChan, completed) + } + + for i := 0; i < n; i++ { + <-completed + } - assert.NotNil(t, result) - assert.Equal(t, wrapper.Storage.NewCustomRowID(connectionFlow.Hash(), result.StartedAt), result.ID) - assert.Equal(t, clientIP.String(), result.SourceIP) - assert.Equal(t, serverIP.String(), result.DestinationIP) - assert.Equal(t, binary.BigEndian.Uint16(clientPort.Raw()), result.SourcePort) - assert.Equal(t, binary.BigEndian.Uint16(serverPort.Raw()), result.DestinationPort) + assert.Len(t, factory.connections, 0) + close(ruleManager.DatabaseUpdateChannel()) wrapper.Destroy(t) } - type TestRuleManager struct { databaseUpdated chan RulesDatabase } diff --git a/connections.go b/connections.go index 3c22a9a..380c8a1 100644 --- a/connections.go +++ b/connections.go @@ -10,8 +10,6 @@ type Connection struct { DestinationPort uint16 `json:"port_dst" bson:"port_dst"` StartedAt time.Time `json:"started_at" bson:"started_at"` ClosedAt time.Time `json:"closed_at" bson:"closed_at"` - ClientPackets int `json:"client_packets" bson:"client_packets"` - ServerPackets int `json:"server_packets" bson:"server_packets"` ClientBytes int `json:"client_bytes" bson:"client_bytes"` ServerBytes int `json:"server_bytes" bson:"server_bytes"` ClientDocuments int `json:"client_documents" bson:"client_documents"` @@ -64,10 +64,6 @@ func NewMongoStorage(uri string, port int, database string) *MongoStorage { } func (storage *MongoStorage) Connect(ctx context.Context) error { - if ctx == nil { - ctx, _ = context.WithTimeout(context.Background(), defaultConnectionTimeout) - } - return storage.client.Connect(ctx) } diff --git a/stream_handler.go b/stream_handler.go index f2309ad..78326c6 100644 --- a/stream_handler.go +++ b/stream_handler.go @@ -28,7 +28,6 @@ type StreamHandler struct { lastPacketSeen time.Time documentsIDs []RowID streamLength int - packetsCount int patternStream hyperscan.Stream patternMatches map[uint][]PatternSlice scanner Scanner @@ -60,8 +59,6 @@ func NewStreamHandler(connection ConnectionHandler, streamFlow StreamFlow, scann // Reassembled implements tcpassembly.Stream's Reassembled function. func (sh *StreamHandler) Reassembled(reassembly []tcpassembly.Reassembly) { for _, r := range reassembly { - sh.packetsCount++ - skip := r.Skip if r.Start { skip = 0 |