diff options
Diffstat (limited to 'pcap_importer.go')
-rw-r--r-- | pcap_importer.go | 192 |
1 files changed, 96 insertions, 96 deletions
diff --git a/pcap_importer.go b/pcap_importer.go index 00c84bd..628b25d 100644 --- a/pcap_importer.go +++ b/pcap_importer.go @@ -7,46 +7,47 @@ import ( "github.com/google/gopacket/layers" "github.com/google/gopacket/pcap" "github.com/google/gopacket/tcpassembly" - "go.mongodb.org/mongo-driver/mongo" - "log" + log "github.com/sirupsen/logrus" "net" - "strconv" "sync" "time" ) const initialAssemblerPoolSize = 16 const flushOlderThan = 5 * time.Minute -const invalidSessionID = "invalid_id" -const importUpdateProgressInterval = 3 * time.Second -const initialPacketPerServicesMapSize = 16 -const importedPcapsCollectionName = "imported_pcaps" +const importUpdateProgressInterval = 100 * time.Millisecond type PcapImporter struct { storage Storage streamPool *tcpassembly.StreamPool assemblers []*tcpassembly.Assembler - sessions map[string]context.CancelFunc + sessions map[string]ImportingSession mAssemblers sync.Mutex mSessions sync.Mutex serverIP gopacket.Endpoint } +type ImportingSession struct { + ID string `json:"id" bson:"_id"` + CompletedAt time.Time `json:"completed_at" bson:"completed_at,omitempty"` + ProcessedPackets int `json:"processed_packets" bson:"processed_packets"` + InvalidPackets int `json:"invalid_packets" bson:"invalid_packets"` + PacketsPerService map[uint16]flowCount `json:"packets_per_service" bson:"packets_per_service"` + ImportingError error `json:"importing_error" bson:"importing_error,omitempty"` + cancelFunc context.CancelFunc +} + type flowCount [2]int -func NewPcapImporter(storage Storage, serverIP net.IP) *PcapImporter { +func NewPcapImporter(storage Storage, serverIP net.IP, rulesManager RulesManager) *PcapImporter { serverEndpoint := layers.NewIPEndpoint(serverIP) - streamFactory := &BiDirectionalStreamFactory{ - storage: storage, - serverIP: serverEndpoint, - } - streamPool := tcpassembly.NewStreamPool(streamFactory) + streamPool := tcpassembly.NewStreamPool(NewBiDirectionalStreamFactory(storage, serverEndpoint, rulesManager)) return &PcapImporter{ storage: storage, streamPool: streamPool, assemblers: make([]*tcpassembly.Assembler, 0, initialAssemblerPoolSize), - sessions: make(map[string]context.CancelFunc), + sessions: make(map[string]ImportingSession), mAssemblers: sync.Mutex{}, mSessions: sync.Mutex{}, serverIP: serverEndpoint, @@ -60,61 +61,83 @@ func NewPcapImporter(storage Storage, serverIP net.IP) *PcapImporter { func (pi *PcapImporter) ImportPcap(fileName string) (string, error) { hash, err := Sha256Sum(fileName) if err != nil { - return invalidSessionID, err + return "", err } pi.mSessions.Lock() - _, ok := pi.sessions[hash] - if ok { + _, isPresent := pi.sessions[hash] + if isPresent { pi.mSessions.Unlock() - return hash, errors.New("another equal session in progress") + return hash, errors.New("pcap already processed") } - doc := OrderedDocument{ - {"_id", hash}, - {"started_at", time.Now()}, - {"completed_at", nil}, - {"processed_packets", 0}, - {"invalid_packets", 0}, - {"packets_per_services", nil}, - {"importing_error", err}, + ctx, cancelFunc := context.WithCancel(context.Background()) + session := ImportingSession{ + ID: hash, + PacketsPerService: make(map[uint16]flowCount), + cancelFunc: cancelFunc, } - ctx, canc := context.WithCancel(context.Background()) - _, err = pi.storage.Insert(importedPcapsCollectionName).Context(ctx).One(doc) - if err != nil { + + if result, err := pi.storage.Insert(ImportingSessions).Context(ctx).One(session); err != nil { pi.mSessions.Unlock() - _, alreadyProcessed := err.(mongo.WriteException) - if alreadyProcessed { - return hash, errors.New("pcap already processed") - } - return hash, err + log.WithError(err).WithField("session", session).Panic("failed to insert a session into database") + } else if result == nil { + pi.mSessions.Unlock() + return hash, errors.New("pcap already processed") } - pi.sessions[hash] = canc + pi.sessions[hash] = session pi.mSessions.Unlock() - go pi.parsePcap(hash, fileName, ctx) + go pi.parsePcap(session, fileName, ctx) return hash, nil } -func (pi *PcapImporter) CancelImport(sessionID string) error { +func (pi *PcapImporter) GetSession(sessionID string) (ImportingSession, bool) { pi.mSessions.Lock() defer pi.mSessions.Unlock() - cancel, ok := pi.sessions[sessionID] - if ok { - delete(pi.sessions, sessionID) - cancel() - return nil - } else { + session, isPresent := pi.sessions[sessionID] + return session, isPresent +} + +func (pi *PcapImporter) CancelSession(sessionID string) error { + pi.mSessions.Lock() + defer pi.mSessions.Unlock() + if session, isPresent := pi.sessions[sessionID]; !isPresent { return errors.New("session " + sessionID + " not found") + } else { + session.cancelFunc() + return nil } } // Read the pcap and save the tcp stream flow to the database -func (pi *PcapImporter) parsePcap(sessionID, fileName string, ctx context.Context) { +func (pi *PcapImporter) parsePcap(session ImportingSession, fileName string, ctx context.Context) { + progressUpdate := func(completed bool, err error) { + if completed { + session.CompletedAt = time.Now() + } + session.ImportingError = err + + dupSession := session + dupSession.PacketsPerService = make(map[uint16]flowCount, len(session.PacketsPerService)) + for key, value := range session.PacketsPerService { + dupSession.PacketsPerService[key] = value + } + + pi.mSessions.Lock() + pi.sessions[session.ID] = dupSession + pi.mSessions.Unlock() + + if _, err = pi.storage.Update(ImportingSessions). + Filter(OrderedDocument{{"_id", session.ID}}).One(session); err != nil { + log.WithError(err).WithField("session", session).Error("failed to update importing stats") + } + } + handle, err := pcap.OpenOffline(fileName) if err != nil { - // TODO: update db and set error + progressUpdate(false, errors.New("failed to process pcap")) return } @@ -125,40 +148,15 @@ func (pi *PcapImporter) parsePcap(sessionID, fileName string, ctx context.Contex firstPacketTime := time.Time{} updateProgressInterval := time.Tick(importUpdateProgressInterval) - processedPackets := 0 - invalidPackets := 0 - packetsPerService := make(map[int]*flowCount, initialPacketPerServicesMapSize) - - progressUpdate := func(completed bool, err error) { - update := UnorderedDocument{ - "processed_packets": processedPackets, - "invalid_packets": invalidPackets, - "packets_per_services": packetsPerService, - "importing_error": err, - } - if completed { - update["completed_at"] = time.Now() - } - - _, _err := pi.storage.Update(importedPcapsCollectionName). - Filter(OrderedDocument{{"_id", sessionID}}). - One(nil) - if _err != nil { - log.Println("can't update importing statistics : ", _err) - } - } - - deleteSession := func() { - pi.mSessions.Lock() - delete(pi.sessions, sessionID) - pi.mSessions.Unlock() + terminate := func() { + handle.Close() + pi.releaseAssembler(assembler) } for { select { case <-ctx.Done(): - handle.Close() - deleteSession() + terminate() progressUpdate(false, errors.New("import process cancelled")) return default: @@ -170,42 +168,44 @@ func (pi *PcapImporter) parsePcap(sessionID, fileName string, ctx context.Contex if !firstPacketTime.IsZero() { assembler.FlushOlderThan(firstPacketTime.Add(-flushOlderThan)) } - pi.releaseAssembler(assembler) - handle.Close() - - deleteSession() + terminate() progressUpdate(true, nil) - return } - processedPackets++ - - if packet.NetworkLayer() == nil || packet.TransportLayer() == nil || - packet.TransportLayer().LayerType() != layers.LayerTypeTCP { // invalid packet - invalidPackets++ - continue - } timestamp := packet.Metadata().Timestamp if firstPacketTime.IsZero() { firstPacketTime = timestamp } + if packet.NetworkLayer() == nil || packet.TransportLayer() == nil || + packet.TransportLayer().LayerType() != layers.LayerTypeTCP { // invalid packet + session.InvalidPackets++ + continue + } + session.ProcessedPackets++ + tcp := packet.TransportLayer().(*layers.TCP) - var servicePort, index int - if packet.NetworkLayer().NetworkFlow().Dst() == pi.serverIP { - servicePort, _ = strconv.Atoi(tcp.DstPort.String()) + var servicePort uint16 + var index int + isDstServer := packet.NetworkLayer().NetworkFlow().Dst() == pi.serverIP + isSrcServer := packet.NetworkLayer().NetworkFlow().Src() == pi.serverIP + if isDstServer && !isSrcServer { + servicePort = uint16(tcp.DstPort) index = 0 - } else { - servicePort, _ = strconv.Atoi(tcp.SrcPort.String()) + } else if isSrcServer && !isDstServer { + servicePort = uint16(tcp.SrcPort) index = 1 + } else { + session.InvalidPackets++ + continue } - fCount, ok := packetsPerService[servicePort] - if !ok { - fCount = &flowCount{0, 0} - packetsPerService[servicePort] = fCount + fCount, isPresent := session.PacketsPerService[servicePort] + if !isPresent { + fCount = flowCount{0, 0} } fCount[index]++ + session.PacketsPerService[servicePort] = fCount assembler.AssembleWithTimestamp(packet.NetworkLayer().NetworkFlow(), tcp, timestamp) case <-updateProgressInterval: |