diff options
-rw-r--r-- | stream_handler.go | 4 | ||||
-rw-r--r-- | stream_handler_test.go | 211 |
2 files changed, 196 insertions, 19 deletions
diff --git a/stream_handler.go b/stream_handler.go index 80d91d6..ce580fc 100644 --- a/stream_handler.go +++ b/stream_handler.go @@ -135,7 +135,7 @@ func (sh *StreamHandler) onMatch(id uint, from uint64, to uint64, flags uint, co // new from == new match sh.patternMatches[id] = append(patternSlices, PatternSlice{from, to}) } else { - patternSlices = make([]PatternSlice, InitialPatternSliceSize) + patternSlices = make([]PatternSlice, 1, InitialPatternSliceSize) patternSlices[0] = PatternSlice{from, to} sh.patternMatches[id] = patternSlices } @@ -169,7 +169,7 @@ func (sh *StreamHandler) generateDocumentKey() string { endpointsHash := sh.streamKey[0].FastHash() ^ sh.streamKey[1].FastHash() ^ sh.streamKey[2].FastHash() ^ sh.streamKey[3].FastHash() binary.BigEndian.PutUint64(hash, endpointsHash) - binary.BigEndian.PutUint64(hash[8:], uint64(sh.timestamps[0].UnixNano())) + binary.BigEndian.PutUint64(hash[8:], uint64(sh.firstPacketSeen.UnixNano())) binary.BigEndian.PutUint16(hash[8:], uint16(len(sh.documentsKeys))) return fmt.Sprintf("%x", hash) diff --git a/stream_handler_test.go b/stream_handler_test.go index a1004dc..425c1b7 100644 --- a/stream_handler_test.go +++ b/stream_handler_test.go @@ -3,6 +3,7 @@ package main import ( "context" "errors" + "fmt" "github.com/flier/gohs/hyperscan" "github.com/google/gopacket/layers" "github.com/google/gopacket/tcpassembly" @@ -23,9 +24,16 @@ const dstPort = 8080 func TestReassemblingEmptyStream(t *testing.T) { patterns, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/nope/", 0)) require.Nil(t, err) - streamHandler := createTestStreamHandler(t, testStorage{}, patterns) + scratch, err := hyperscan.NewScratch(patterns) + require.Nil(t, err) + streamHandler := createTestStreamHandler(testStorage{}, patterns, scratch) - streamHandler.Reassembled([]tcpassembly.Reassembly{}) + streamHandler.Reassembled([]tcpassembly.Reassembly{{ + Bytes: []byte{}, + Skip: 0, + Start: true, + End: true, + }}) assert.Len(t, streamHandler.indexes, 0, "indexes") assert.Len(t, streamHandler.timestamps, 0, "timestamps") assert.Len(t, streamHandler.lossBlocks, 0) @@ -42,18 +50,26 @@ func TestReassemblingEmptyStream(t *testing.T) { } streamHandler.ReassemblyComplete() assert.Equal(t, 42, expected) + + err = scratch.Free() + require.Nil(t, err, "free scratch") + err = patterns.Close() + require.Nil(t, err, "close stream database") } -func TestReassemblingSingleDocumentStream(t *testing.T) { + +func TestReassemblingSingleDocument(t *testing.T) { patterns, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/impossible_to_match/", 0)) require.Nil(t, err) + scratch, err := hyperscan.NewScratch(patterns) + require.Nil(t, err) storage := &testStorage{} - streamHandler := createTestStreamHandler(t, storage, patterns) + streamHandler := createTestStreamHandler(storage, patterns, scratch) payloadLen := 256 - firstTime := time.Unix(1000000000, 0) - middleTime := time.Unix(1000000010, 0) - lastTime := time.Unix(1000000020, 0) + firstTime := time.Unix(0, 0) + middleTime := time.Unix(10, 0) + lastTime := time.Unix(20, 0) data := make([]byte, MaxDocumentSize) rand.Read(data) reassembles := make([]tcpassembly.Reassembly, MaxDocumentSize / payloadLen) @@ -85,7 +101,7 @@ func TestReassemblingSingleDocumentStream(t *testing.T) { storage.insertFunc = func(ctx context.Context, collectionName string, document interface{}) (i interface{}, err error) { od := document.(OrderedDocument) assert.Equal(t, "connection_streams", collectionName) - assert.Equal(t, "bb41a60281cfae830000b6b3a7640000", od[0].Value) + assert.Equal(t, "bb41a60281cfae830000000000000000", od[0].Value) assert.Equal(t, nil, od[1].Value) assert.Equal(t, 0, od[2].Value) assert.Equal(t, data, od[3].Value) @@ -102,37 +118,198 @@ func TestReassemblingSingleDocumentStream(t *testing.T) { inserted = false } - assert.Equal(t, data, streamHandler.buffer.Bytes(), "buffer should contains the same bytes of reassembles") - assert.Equal(t, indexes, streamHandler.indexes, "indexes") - assert.Equal(t, timestamps, streamHandler.timestamps, "timestamps") - assert.Equal(t, lossBlocks, streamHandler.lossBlocks, "lossBlocks") + completed := false + streamHandler.connection.(*testConnectionHandler).onComplete = func(handler *StreamHandler) { + completed = true + } + streamHandler.ReassemblyComplete() + assert.Equal(t, len(data), streamHandler.currentIndex) assert.Equal(t, firstTime, streamHandler.firstPacketSeen) assert.Equal(t, lastTime, streamHandler.lastPacketSeen) - assert.Len(t, streamHandler.documentsKeys, 0) + assert.Len(t, streamHandler.documentsKeys, 1) assert.Equal(t, len(data), streamHandler.streamLength) assert.Len(t, streamHandler.patternMatches, 0) + assert.Equal(t, true, inserted, "inserted") + assert.Equal(t, true, completed, "completed") + + err = scratch.Free() + require.Nil(t, err, "free scratch") + err = patterns.Close() + require.Nil(t, err, "close stream database") +} + + +func TestReassemblingMultipleDocuments(t *testing.T) { + patterns, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/impossible_to_match/", 0)) + require.Nil(t, err) + scratch, err := hyperscan.NewScratch(patterns) + require.Nil(t, err) + storage := &testStorage{} + streamHandler := createTestStreamHandler(storage, patterns, scratch) + + payloadLen := 256 + firstTime := time.Unix(0, 0) + middleTime := time.Unix(10, 0) + lastTime := time.Unix(20, 0) + dataSize := MaxDocumentSize*2 + data := make([]byte, dataSize) + rand.Read(data) + reassembles := make([]tcpassembly.Reassembly, dataSize / payloadLen) + indexes := make([]int, dataSize / payloadLen) + timestamps := make([]time.Time, dataSize / payloadLen) + lossBlocks := make([]bool, dataSize / payloadLen) + for i := 0; i < len(reassembles); i++ { + var seen time.Time + if i == 0 { + seen = firstTime + } else if i == len(reassembles)-1 { + seen = lastTime + } else { + seen = middleTime + } + + reassembles[i] = tcpassembly.Reassembly{ + Bytes: data[i*payloadLen:(i+1)*payloadLen], + Skip: 0, + Start: i == 0, + End: i == len(reassembles)-1, + Seen: seen, + } + indexes[i] = i*payloadLen % MaxDocumentSize + timestamps[i] = seen + } + + inserted := 0 + storage.insertFunc = func(ctx context.Context, collectionName string, document interface{}) (i interface{}, err error) { + od := document.(OrderedDocument) + blockLen := MaxDocumentSize / payloadLen + assert.Equal(t, "connection_streams", collectionName) + assert.Equal(t, fmt.Sprintf("bb41a60281cfae83000%v000000000000", inserted), od[0].Value) + assert.Equal(t, nil, od[1].Value) + assert.Equal(t, inserted, od[2].Value) + assert.Equal(t, data[MaxDocumentSize*inserted:MaxDocumentSize*(inserted+1)], od[3].Value) + assert.Equal(t, indexes[blockLen*inserted:blockLen*(inserted+1)], od[4].Value) + assert.Equal(t, timestamps[blockLen*inserted:blockLen*(inserted+1)], od[5].Value) + assert.Equal(t, lossBlocks[blockLen*inserted:blockLen*(inserted+1)], od[6].Value) + assert.Len(t, od[7].Value, 0) + inserted += 1 + + return nil, nil + } + + streamHandler.Reassembled(reassembles) + if !assert.Equal(t, 1, inserted) { + inserted = 1 + } + completed := false streamHandler.connection.(*testConnectionHandler).onComplete = func(handler *StreamHandler) { completed = true } streamHandler.ReassemblyComplete() + + assert.Equal(t, MaxDocumentSize, streamHandler.currentIndex) + assert.Equal(t, firstTime, streamHandler.firstPacketSeen) + assert.Equal(t, lastTime, streamHandler.lastPacketSeen) + assert.Len(t, streamHandler.documentsKeys, 2) + assert.Equal(t, len(data), streamHandler.streamLength) + assert.Len(t, streamHandler.patternMatches, 0) + + assert.Equal(t, 2, inserted, "inserted") + assert.Equal(t, true, completed, "completed") + + err = scratch.Free() + require.Nil(t, err, "free scratch") + err = patterns.Close() + require.Nil(t, err, "close stream database") +} + +func TestReassemblingPatternMatching(t *testing.T) { + a, err := hyperscan.ParsePattern("/a{8}/i") + require.Nil(t, err) + a.Id = 0 + a.Flags |= hyperscan.SomLeftMost + b, err := hyperscan.ParsePattern("/b[c]+b/i") + require.Nil(t, err) + b.Id = 1 + b.Flags |= hyperscan.SomLeftMost + d, err := hyperscan.ParsePattern("/[d]+e[d]+/i") + require.Nil(t, err) + d.Id = 2 + d.Flags |= hyperscan.SomLeftMost + + payload := "aaaaaaaa0aaaaaaaaaa0bbbcccbbb0dddeddddedddd" + expected := map[uint][]PatternSlice{ + 0: {{0, 8}, {9, 17}, {10, 18}, {11, 19}}, + 1: {{22, 27}}, + 2: {{30, 38}, {34, 43}}, + } + + patterns, err := hyperscan.NewStreamDatabase(a, b, d) + require.Nil(t, err) + scratch, err := hyperscan.NewScratch(patterns) + require.Nil(t, err) + storage := &testStorage{} + streamHandler := createTestStreamHandler(storage, patterns, scratch) + + seen := time.Unix(0, 0) + inserted := false + storage.insertFunc = func(ctx context.Context, collectionName string, document interface{}) (i interface{}, err error) { + od := document.(OrderedDocument) + assert.Equal(t, "connection_streams", collectionName) + assert.Equal(t, "bb41a60281cfae830000000000000000", od[0].Value) + assert.Equal(t, nil, od[1].Value) + assert.Equal(t, 0, od[2].Value) + assert.Equal(t, []byte(payload), od[3].Value) + assert.Equal(t, []int{0}, od[4].Value) + assert.Equal(t, []time.Time{seen}, od[5].Value) + assert.Equal(t, []bool{false}, od[6].Value) + assert.Equal(t, expected, od[7].Value) + inserted = true + + return nil, nil + } + + streamHandler.Reassembled([]tcpassembly.Reassembly{{ + Bytes: []byte(payload), + Skip: 0, + Start: true, + End: true, + Seen: seen, + }}) + assert.Equal(t, false, inserted) + + completed := false + streamHandler.connection.(*testConnectionHandler).onComplete = func(handler *StreamHandler) { + completed = true + } + streamHandler.ReassemblyComplete() + + assert.Equal(t, len(payload), streamHandler.currentIndex) + assert.Equal(t, seen, streamHandler.firstPacketSeen) + assert.Equal(t, seen, streamHandler.lastPacketSeen) + assert.Len(t, streamHandler.documentsKeys, 1) + assert.Equal(t, len(payload), streamHandler.streamLength) + assert.Equal(t, true, inserted, "inserted") assert.Equal(t, true, completed, "completed") + + err = scratch.Free() + require.Nil(t, err, "free scratch") + err = patterns.Close() + require.Nil(t, err, "close stream database") } -func createTestStreamHandler(t *testing.T, storage Storage, patterns hyperscan.StreamDatabase) StreamHandler { +func createTestStreamHandler(storage Storage, patterns hyperscan.StreamDatabase, scratch *hyperscan.Scratch) StreamHandler { testConnectionHandler := &testConnectionHandler{ storage: storage, context: context.Background(), patterns: patterns, } - scratch, err := hyperscan.NewScratch(patterns) - require.Nil(t, err) - srcIp := layers.NewIPEndpoint(net.ParseIP(testSrcIp)) dstIp := layers.NewIPEndpoint(net.ParseIP(testDstIp)) srcPort := layers.NewTCPPortEndpoint(srcPort) |