aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--pcap_importer.go73
-rw-r--r--pcap_importer_test.go112
-rw-r--r--test_data/icmp.pcapbin0 -> 228024 bytes
3 files changed, 120 insertions, 65 deletions
diff --git a/pcap_importer.go b/pcap_importer.go
index 830ac38..bb09867 100644
--- a/pcap_importer.go
+++ b/pcap_importer.go
@@ -33,9 +33,9 @@ type ImportingSession struct {
ProcessedPackets int `json:"processed_packets" bson:"processed_packets"`
InvalidPackets int `json:"invalid_packets" bson:"invalid_packets"`
PacketsPerService map[uint16]flowCount `json:"packets_per_service" bson:"packets_per_service"`
- ImportingError error `json:"importing_error" bson:"importing_error,omitempty"`
+ ImportingError string `json:"importing_error" bson:"importing_error,omitempty"`
cancelFunc context.CancelFunc
- completed chan error
+ completed chan string
}
type flowCount [2]int
@@ -77,16 +77,9 @@ func (pi *PcapImporter) ImportPcap(fileName string) (string, error) {
ID: hash,
PacketsPerService: make(map[uint16]flowCount),
cancelFunc: cancelFunc,
- completed: make(chan error),
+ completed: make(chan string),
}
- if result, err := pi.storage.Insert(ImportingSessions).Context(ctx).One(session); err != nil {
- pi.mSessions.Unlock()
- log.WithError(err).WithField("session", session).Panic("failed to insert a session into database")
- } else if result == nil {
- pi.mSessions.Unlock()
- return hash, errors.New("pcap already processed")
- }
pi.sessions[hash] = session
pi.mSessions.Unlock()
@@ -115,34 +108,9 @@ func (pi *PcapImporter) CancelSession(sessionID string) error {
// Read the pcap and save the tcp stream flow to the database
func (pi *PcapImporter) parsePcap(session ImportingSession, fileName string, ctx context.Context) {
- progressUpdate := func(completed bool, err error) {
- if completed {
- session.CompletedAt = time.Now()
- }
- session.ImportingError = err
-
- dupSession := session
- dupSession.PacketsPerService = make(map[uint16]flowCount, len(session.PacketsPerService))
- for key, value := range session.PacketsPerService {
- dupSession.PacketsPerService[key] = value
- }
-
- pi.mSessions.Lock()
- pi.sessions[session.ID] = dupSession
- pi.mSessions.Unlock()
-
- if completed || err != nil {
- if _, err = pi.storage.Update(ImportingSessions).
- Filter(OrderedDocument{{"_id", session.ID}}).One(session); err != nil {
- log.WithError(err).WithField("session", session).Error("failed to update importing stats")
- }
- session.completed <- err
- }
- }
-
handle, err := pcap.OpenOffline(fileName)
if err != nil {
- progressUpdate(false, errors.New("failed to process pcap"))
+ pi.progressUpdate(session, false, "failed to process pcap")
log.WithError(err).WithFields(log.Fields{"session": session, "fileName": fileName}).
Error("failed to open pcap")
return
@@ -160,7 +128,7 @@ func (pi *PcapImporter) parsePcap(session ImportingSession, fileName string, ctx
case <-ctx.Done():
handle.Close()
pi.releaseAssembler(assembler)
- progressUpdate(false, errors.New("import process cancelled"))
+ pi.progressUpdate(session, false, "import process cancelled")
return
default:
}
@@ -173,7 +141,7 @@ func (pi *PcapImporter) parsePcap(session ImportingSession, fileName string, ctx
}
handle.Close()
pi.releaseAssembler(assembler)
- progressUpdate(true, nil)
+ pi.progressUpdate(session, true, "")
return
}
@@ -182,12 +150,13 @@ func (pi *PcapImporter) parsePcap(session ImportingSession, fileName string, ctx
firstPacketTime = timestamp
}
+ session.ProcessedPackets++
+
if packet.NetworkLayer() == nil || packet.TransportLayer() == nil ||
packet.TransportLayer().LayerType() != layers.LayerTypeTCP { // invalid packet
session.InvalidPackets++
continue
}
- session.ProcessedPackets++
tcp := packet.TransportLayer().(*layers.TCP)
var servicePort uint16
@@ -213,8 +182,32 @@ func (pi *PcapImporter) parsePcap(session ImportingSession, fileName string, ctx
assembler.AssembleWithTimestamp(packet.NetworkLayer().NetworkFlow(), tcp, timestamp)
case <-updateProgressInterval:
- progressUpdate(false, nil)
+ pi.progressUpdate(session, false, "")
+ }
+ }
+}
+
+func (pi *PcapImporter) progressUpdate(session ImportingSession, completed bool, err string) {
+ if completed {
+ session.CompletedAt = time.Now()
+ }
+ session.ImportingError = err
+
+ packetsPerService := session.PacketsPerService
+ session.PacketsPerService = make(map[uint16]flowCount, len(packetsPerService))
+ for key, value := range packetsPerService {
+ session.PacketsPerService[key] = value
+ }
+
+ pi.mSessions.Lock()
+ pi.sessions[session.ID] = session
+ pi.mSessions.Unlock()
+
+ if completed || session.ImportingError != "" {
+ if _, _err := pi.storage.Insert(ImportingSessions).One(session); _err != nil {
+ log.WithError(_err).WithField("session", session).Error("failed to insert importing stats")
}
+ session.completed <- session.ImportingError
}
}
diff --git a/pcap_importer_test.go b/pcap_importer_test.go
index b38d2c9..bda2cb2 100644
--- a/pcap_importer_test.go
+++ b/pcap_importer_test.go
@@ -8,20 +8,86 @@ import (
"github.com/google/gopacket/tcpassembly/tcpreader"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "go.uber.org/atomic"
"net"
"sync"
"testing"
+ "time"
)
func TestImportPcap(t *testing.T) {
wrapper := NewTestStorageWrapper(t)
+ pcapImporter := newTestPcapImporter(wrapper, "172.17.0.3")
+
+ pcapImporter.releaseAssembler(pcapImporter.takeAssembler())
+
+ sessionID, err := pcapImporter.ImportPcap("test_data/ping_pong_10000.pcap")
+ require.NoError(t, err)
+
+ duplicateSessionID, err := pcapImporter.ImportPcap("test_data/ping_pong_10000.pcap")
+ require.Error(t, err)
+ assert.Equal(t, sessionID, duplicateSessionID)
+
+ _, isPresent := pcapImporter.GetSession("invalid")
+ assert.False(t, isPresent)
+
+ session := waitSessionCompletion(t, pcapImporter, sessionID)
+ assert.Equal(t, 15008, session.ProcessedPackets)
+ assert.Equal(t, 0, session.InvalidPackets)
+ assert.Equal(t, map[uint16]flowCount{9999: {10004, 5004}}, session.PacketsPerService)
+ assert.Zero(t, session.ImportingError)
+
+ checkSessionEquals(t, wrapper, session)
+
+ wrapper.Destroy(t)
+}
+
+func TestCancelImportSession(t *testing.T) {
+ wrapper := NewTestStorageWrapper(t)
+ pcapImporter := newTestPcapImporter(wrapper, "172.17.0.3")
+
+ sessionID, err := pcapImporter.ImportPcap("test_data/ping_pong_10000.pcap")
+ require.NoError(t, err)
+
+ assert.Error(t, pcapImporter.CancelSession("invalid"))
+ assert.NoError(t, pcapImporter.CancelSession(sessionID))
+
+ session := waitSessionCompletion(t, pcapImporter, sessionID)
+ assert.Zero(t, session.CompletedAt)
+ assert.Equal(t, 0, session.ProcessedPackets)
+ assert.Equal(t, 0, session.InvalidPackets)
+ assert.Equal(t, map[uint16]flowCount{}, session.PacketsPerService)
+ assert.NotZero(t, session.ImportingError)
+
+ checkSessionEquals(t, wrapper, session)
+
+ wrapper.Destroy(t)
+}
+
+func TestImportNoTcpPackets(t *testing.T) {
+ wrapper := NewTestStorageWrapper(t)
+ pcapImporter := newTestPcapImporter(wrapper, "172.17.0.4")
+
+ sessionID, err := pcapImporter.ImportPcap("test_data/icmp.pcap")
+ require.NoError(t, err)
+
+ session := waitSessionCompletion(t, pcapImporter, sessionID)
+ assert.Equal(t, 2000, session.ProcessedPackets)
+ assert.Equal(t, 2000, session.InvalidPackets)
+ assert.Equal(t, map[uint16]flowCount{}, session.PacketsPerService)
+ assert.Zero(t, session.ImportingError)
+
+ checkSessionEquals(t, wrapper, session)
+
+ wrapper.Destroy(t)
+}
+
+func newTestPcapImporter(wrapper *TestStorageWrapper, serverIP string) *PcapImporter {
wrapper.AddCollection(ImportingSessions)
- serverEndpoint := layers.NewIPEndpoint(net.ParseIP("172.17.0.3"))
+ serverEndpoint := layers.NewIPEndpoint(net.ParseIP(serverIP))
streamPool := tcpassembly.NewStreamPool(&testStreamFactory{})
- pcapImporter := PcapImporter{
+ return &PcapImporter{
storage: wrapper.Storage,
streamPool: streamPool,
assemblers: make([]*tcpassembly.Assembler, 0, initialAssemblerPoolSize),
@@ -30,40 +96,36 @@ func TestImportPcap(t *testing.T) {
mSessions: sync.Mutex{},
serverIP: serverEndpoint,
}
+}
- sessionID, err := pcapImporter.ImportPcap("test_data/ping_pong_10000.pcap")
- require.NoError(t, err)
- assert.NotZero(t, sessionID)
-
- duplicateSessionID, err := pcapImporter.ImportPcap("test_data/ping_pong_10000.pcap")
- require.Error(t, err)
- assert.Equal(t, sessionID, duplicateSessionID)
-
- _, isPresent := pcapImporter.GetSession("invalid")
- assert.False(t, isPresent)
-
+func waitSessionCompletion(t *testing.T, pcapImporter *PcapImporter, sessionID string) ImportingSession {
session, isPresent := pcapImporter.GetSession(sessionID)
require.True(t, isPresent)
- err, _ = <- session.completed
+ <-session.completed
session, isPresent = pcapImporter.GetSession(sessionID)
- require.True(t, isPresent)
- assert.NoError(t, err)
+ assert.True(t, isPresent)
assert.Equal(t, sessionID, session.ID)
- assert.Equal(t, 15008, session.ProcessedPackets)
- assert.Equal(t, 0, session.InvalidPackets)
- assert.Equal(t, map[uint16]flowCount{9999: {10004, 5004}}, session.PacketsPerService)
- assert.NoError(t, session.ImportingError)
- wrapper.Destroy(t)
+ return session
+}
+
+func checkSessionEquals(t *testing.T, wrapper *TestStorageWrapper, session ImportingSession) {
+ var result ImportingSession
+ assert.NoError(t, wrapper.Storage.Find(ImportingSessions).Filter(OrderedDocument{{"_id", session.ID}}).
+ Context(wrapper.Context).First(&result))
+ assert.Equal(t, session.CompletedAt.Unix(), result.CompletedAt.Unix())
+ session.CompletedAt = time.Time{}
+ result.CompletedAt = time.Time{}
+ session.cancelFunc = nil
+ session.completed = nil
+ assert.Equal(t, session, result)
}
-type testStreamFactory struct{
- counter atomic.Int32
+type testStreamFactory struct {
}
func (sf *testStreamFactory) New(_, _ gopacket.Flow) tcpassembly.Stream {
- sf.counter.Inc()
reader := tcpreader.NewReaderStream()
go func() {
buffer := bufio.NewReader(&reader)
diff --git a/test_data/icmp.pcap b/test_data/icmp.pcap
new file mode 100644
index 0000000..3b6282c
--- /dev/null
+++ b/test_data/icmp.pcap
Binary files differ