diff options
 pcap_importer.go      |  73 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 pcap_importer_test.go | 112 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 test_data/icmp.pcap   | bin 0 -> 228024 bytes
 3 files changed, 120 insertions(+), 65 deletions(-)
diff --git a/pcap_importer.go b/pcap_importer.go index 830ac38..bb09867 100644 --- a/pcap_importer.go +++ b/pcap_importer.go @@ -33,9 +33,9 @@ type ImportingSession struct { ProcessedPackets int `json:"processed_packets" bson:"processed_packets"` InvalidPackets int `json:"invalid_packets" bson:"invalid_packets"` PacketsPerService map[uint16]flowCount `json:"packets_per_service" bson:"packets_per_service"` - ImportingError error `json:"importing_error" bson:"importing_error,omitempty"` + ImportingError string `json:"importing_error" bson:"importing_error,omitempty"` cancelFunc context.CancelFunc - completed chan error + completed chan string } type flowCount [2]int @@ -77,16 +77,9 @@ func (pi *PcapImporter) ImportPcap(fileName string) (string, error) { ID: hash, PacketsPerService: make(map[uint16]flowCount), cancelFunc: cancelFunc, - completed: make(chan error), + completed: make(chan string), } - if result, err := pi.storage.Insert(ImportingSessions).Context(ctx).One(session); err != nil { - pi.mSessions.Unlock() - log.WithError(err).WithField("session", session).Panic("failed to insert a session into database") - } else if result == nil { - pi.mSessions.Unlock() - return hash, errors.New("pcap already processed") - } pi.sessions[hash] = session pi.mSessions.Unlock() @@ -115,34 +108,9 @@ func (pi *PcapImporter) CancelSession(sessionID string) error { // Read the pcap and save the tcp stream flow to the database func (pi *PcapImporter) parsePcap(session ImportingSession, fileName string, ctx context.Context) { - progressUpdate := func(completed bool, err error) { - if completed { - session.CompletedAt = time.Now() - } - session.ImportingError = err - - dupSession := session - dupSession.PacketsPerService = make(map[uint16]flowCount, len(session.PacketsPerService)) - for key, value := range session.PacketsPerService { - dupSession.PacketsPerService[key] = value - } - - pi.mSessions.Lock() - pi.sessions[session.ID] = dupSession - pi.mSessions.Unlock() - - if completed 
|| err != nil { - if _, err = pi.storage.Update(ImportingSessions). - Filter(OrderedDocument{{"_id", session.ID}}).One(session); err != nil { - log.WithError(err).WithField("session", session).Error("failed to update importing stats") - } - session.completed <- err - } - } - handle, err := pcap.OpenOffline(fileName) if err != nil { - progressUpdate(false, errors.New("failed to process pcap")) + pi.progressUpdate(session, false, "failed to process pcap") log.WithError(err).WithFields(log.Fields{"session": session, "fileName": fileName}). Error("failed to open pcap") return @@ -160,7 +128,7 @@ func (pi *PcapImporter) parsePcap(session ImportingSession, fileName string, ctx case <-ctx.Done(): handle.Close() pi.releaseAssembler(assembler) - progressUpdate(false, errors.New("import process cancelled")) + pi.progressUpdate(session, false, "import process cancelled") return default: } @@ -173,7 +141,7 @@ func (pi *PcapImporter) parsePcap(session ImportingSession, fileName string, ctx } handle.Close() pi.releaseAssembler(assembler) - progressUpdate(true, nil) + pi.progressUpdate(session, true, "") return } @@ -182,12 +150,13 @@ func (pi *PcapImporter) parsePcap(session ImportingSession, fileName string, ctx firstPacketTime = timestamp } + session.ProcessedPackets++ + if packet.NetworkLayer() == nil || packet.TransportLayer() == nil || packet.TransportLayer().LayerType() != layers.LayerTypeTCP { // invalid packet session.InvalidPackets++ continue } - session.ProcessedPackets++ tcp := packet.TransportLayer().(*layers.TCP) var servicePort uint16 @@ -213,8 +182,32 @@ func (pi *PcapImporter) parsePcap(session ImportingSession, fileName string, ctx assembler.AssembleWithTimestamp(packet.NetworkLayer().NetworkFlow(), tcp, timestamp) case <-updateProgressInterval: - progressUpdate(false, nil) + pi.progressUpdate(session, false, "") + } + } +} + +func (pi *PcapImporter) progressUpdate(session ImportingSession, completed bool, err string) { + if completed { + session.CompletedAt = 
time.Now() + } + session.ImportingError = err + + packetsPerService := session.PacketsPerService + session.PacketsPerService = make(map[uint16]flowCount, len(packetsPerService)) + for key, value := range packetsPerService { + session.PacketsPerService[key] = value + } + + pi.mSessions.Lock() + pi.sessions[session.ID] = session + pi.mSessions.Unlock() + + if completed || session.ImportingError != "" { + if _, _err := pi.storage.Insert(ImportingSessions).One(session); _err != nil { + log.WithError(_err).WithField("session", session).Error("failed to insert importing stats") } + session.completed <- session.ImportingError } } diff --git a/pcap_importer_test.go b/pcap_importer_test.go index b38d2c9..bda2cb2 100644 --- a/pcap_importer_test.go +++ b/pcap_importer_test.go @@ -8,20 +8,86 @@ import ( "github.com/google/gopacket/tcpassembly/tcpreader" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.uber.org/atomic" "net" "sync" "testing" + "time" ) func TestImportPcap(t *testing.T) { wrapper := NewTestStorageWrapper(t) + pcapImporter := newTestPcapImporter(wrapper, "172.17.0.3") + + pcapImporter.releaseAssembler(pcapImporter.takeAssembler()) + + sessionID, err := pcapImporter.ImportPcap("test_data/ping_pong_10000.pcap") + require.NoError(t, err) + + duplicateSessionID, err := pcapImporter.ImportPcap("test_data/ping_pong_10000.pcap") + require.Error(t, err) + assert.Equal(t, sessionID, duplicateSessionID) + + _, isPresent := pcapImporter.GetSession("invalid") + assert.False(t, isPresent) + + session := waitSessionCompletion(t, pcapImporter, sessionID) + assert.Equal(t, 15008, session.ProcessedPackets) + assert.Equal(t, 0, session.InvalidPackets) + assert.Equal(t, map[uint16]flowCount{9999: {10004, 5004}}, session.PacketsPerService) + assert.Zero(t, session.ImportingError) + + checkSessionEquals(t, wrapper, session) + + wrapper.Destroy(t) +} + +func TestCancelImportSession(t *testing.T) { + wrapper := NewTestStorageWrapper(t) + pcapImporter := 
newTestPcapImporter(wrapper, "172.17.0.3") + + sessionID, err := pcapImporter.ImportPcap("test_data/ping_pong_10000.pcap") + require.NoError(t, err) + + assert.Error(t, pcapImporter.CancelSession("invalid")) + assert.NoError(t, pcapImporter.CancelSession(sessionID)) + + session := waitSessionCompletion(t, pcapImporter, sessionID) + assert.Zero(t, session.CompletedAt) + assert.Equal(t, 0, session.ProcessedPackets) + assert.Equal(t, 0, session.InvalidPackets) + assert.Equal(t, map[uint16]flowCount{}, session.PacketsPerService) + assert.NotZero(t, session.ImportingError) + + checkSessionEquals(t, wrapper, session) + + wrapper.Destroy(t) +} + +func TestImportNoTcpPackets(t *testing.T) { + wrapper := NewTestStorageWrapper(t) + pcapImporter := newTestPcapImporter(wrapper, "172.17.0.4") + + sessionID, err := pcapImporter.ImportPcap("test_data/icmp.pcap") + require.NoError(t, err) + + session := waitSessionCompletion(t, pcapImporter, sessionID) + assert.Equal(t, 2000, session.ProcessedPackets) + assert.Equal(t, 2000, session.InvalidPackets) + assert.Equal(t, map[uint16]flowCount{}, session.PacketsPerService) + assert.Zero(t, session.ImportingError) + + checkSessionEquals(t, wrapper, session) + + wrapper.Destroy(t) +} + +func newTestPcapImporter(wrapper *TestStorageWrapper, serverIP string) *PcapImporter { wrapper.AddCollection(ImportingSessions) - serverEndpoint := layers.NewIPEndpoint(net.ParseIP("172.17.0.3")) + serverEndpoint := layers.NewIPEndpoint(net.ParseIP(serverIP)) streamPool := tcpassembly.NewStreamPool(&testStreamFactory{}) - pcapImporter := PcapImporter{ + return &PcapImporter{ storage: wrapper.Storage, streamPool: streamPool, assemblers: make([]*tcpassembly.Assembler, 0, initialAssemblerPoolSize), @@ -30,40 +96,36 @@ func TestImportPcap(t *testing.T) { mSessions: sync.Mutex{}, serverIP: serverEndpoint, } +} - sessionID, err := pcapImporter.ImportPcap("test_data/ping_pong_10000.pcap") - require.NoError(t, err) - assert.NotZero(t, sessionID) - - 
duplicateSessionID, err := pcapImporter.ImportPcap("test_data/ping_pong_10000.pcap") - require.Error(t, err) - assert.Equal(t, sessionID, duplicateSessionID) - - _, isPresent := pcapImporter.GetSession("invalid") - assert.False(t, isPresent) - +func waitSessionCompletion(t *testing.T, pcapImporter *PcapImporter, sessionID string) ImportingSession { session, isPresent := pcapImporter.GetSession(sessionID) require.True(t, isPresent) - err, _ = <- session.completed + <-session.completed session, isPresent = pcapImporter.GetSession(sessionID) - require.True(t, isPresent) - assert.NoError(t, err) + assert.True(t, isPresent) assert.Equal(t, sessionID, session.ID) - assert.Equal(t, 15008, session.ProcessedPackets) - assert.Equal(t, 0, session.InvalidPackets) - assert.Equal(t, map[uint16]flowCount{9999: {10004, 5004}}, session.PacketsPerService) - assert.NoError(t, session.ImportingError) - wrapper.Destroy(t) + return session +} + +func checkSessionEquals(t *testing.T, wrapper *TestStorageWrapper, session ImportingSession) { + var result ImportingSession + assert.NoError(t, wrapper.Storage.Find(ImportingSessions).Filter(OrderedDocument{{"_id", session.ID}}). + Context(wrapper.Context).First(&result)) + assert.Equal(t, session.CompletedAt.Unix(), result.CompletedAt.Unix()) + session.CompletedAt = time.Time{} + result.CompletedAt = time.Time{} + session.cancelFunc = nil + session.completed = nil + assert.Equal(t, session, result) } -type testStreamFactory struct{ - counter atomic.Int32 +type testStreamFactory struct { } func (sf *testStreamFactory) New(_, _ gopacket.Flow) tcpassembly.Stream { - sf.counter.Inc() reader := tcpreader.NewReaderStream() go func() { buffer := bufio.NewReader(&reader) diff --git a/test_data/icmp.pcap b/test_data/icmp.pcap Binary files differnew file mode 100644 index 0000000..3b6282c --- /dev/null +++ b/test_data/icmp.pcap |