diff options
author | Emiliano Ciavatta | 2020-04-13 15:12:35 +0000 |
---|---|---|
committer | Emiliano Ciavatta | 2020-04-13 15:12:35 +0000 |
commit | a56a4e391d541ae05de0203f3d493edc3b04681d (patch) | |
tree | ab9344a650305aafb5afe552dc8cad63684de643 | |
parent | 7113463dead05631339fdab94de9440201c42489 (diff) |
Add AddRoute tests
-rw-r--r-- | caronte.go | 3 | ||||
-rw-r--r-- | connection_handler.go | 2 | ||||
-rw-r--r-- | connection_handler_test.go | 32 | ||||
-rw-r--r-- | routes.go | 118 | ||||
-rw-r--r-- | rules_manager.go | 103 | ||||
-rw-r--r-- | rules_manager_test.go | 77 | ||||
-rw-r--r-- | storage.go | 24 | ||||
-rw-r--r-- | stream_handler.go | 2 | ||||
-rw-r--r-- | utils.go | 32 |
9 files changed, 267 insertions, 126 deletions
@@ -24,8 +24,9 @@ func main() { log.WithError(err).Fatal("failed to connect to MongoDB") } + rulesManager := NewRulesManager(storage) router := gin.Default() - ApplicationRoutes(router) + ApplicationRoutes(router, rulesManager) err = router.Run(fmt.Sprintf("%s:%v", *bindAddress, *bindPort)) if err != nil { log.WithError(err).Fatal("failed to create the server") diff --git a/connection_handler.go b/connection_handler.go index b8bddc9..e4730cc 100644 --- a/connection_handler.go +++ b/connection_handler.go @@ -188,7 +188,7 @@ func (ch *connectionHandlerImpl) Complete(handler *StreamHandler) { server = handler } - connectionID := ch.Storage().NewCustomRowID(ch.connectionFlow.Hash(), startedAt) + connectionID := CustomRowID(ch.connectionFlow.Hash(), startedAt) connection := Connection{ ID: connectionID, SourceIP: ch.connectionFlow[0].String(), diff --git a/connection_handler_test.go b/connection_handler_test.go index 20eb041..dd097e4 100644 --- a/connection_handler_test.go +++ b/connection_handler_test.go @@ -18,7 +18,7 @@ import ( func TestTakeReleaseScanners(t *testing.T) { wrapper := NewTestStorageWrapper(t) serverIP := layers.NewIPEndpoint(net.ParseIP(testDstIP)) - ruleManager := TestRuleManager{ + ruleManager := TestRulesManager{ databaseUpdated: make(chan RulesDatabase), } @@ -26,7 +26,7 @@ func TestTakeReleaseScanners(t *testing.T) { require.NoError(t, err) factory := NewBiDirectionalStreamFactory(wrapper.Storage, serverIP, &ruleManager) - version := wrapper.Storage.NewRowID() + version := NewRowID() ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version} time.Sleep(10 * time.Millisecond) @@ -36,7 +36,7 @@ func TestTakeReleaseScanners(t *testing.T) { assert.Equal(t, scanner.version, version) if i%50 == 0 { - version = wrapper.Storage.NewRowID() + version = NewRowID() ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version} time.Sleep(10 * time.Millisecond) } @@ -54,7 +54,7 @@ func TestTakeReleaseScanners(t *testing.T) { } assert.Len(t, factory.scanners, n) - version = wrapper.Storage.NewRowID() + version = NewRowID() ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version} time.Sleep(10 * time.Millisecond) @@ -73,7 +73,7 @@ func TestConnectionFactory(t *testing.T) { wrapper.AddCollection(Connections) wrapper.AddCollection(ConnectionStreams) - ruleManager := TestRuleManager{ + ruleManager := TestRulesManager{ databaseUpdated: make(chan RulesDatabase), } @@ -89,7 +89,7 @@ func TestConnectionFactory(t *testing.T) { require.NoError(t, err) factory := NewBiDirectionalStreamFactory(wrapper.Storage, serverIP, &ruleManager) - version := wrapper.Storage.NewRowID() + version := NewRowID() ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version} time.Sleep(10 * time.Millisecond) @@ -122,13 +122,13 @@ func TestConnectionFactory(t *testing.T) { var result Connection connectionFlow := StreamFlow{netFlow.Src(), netFlow.Dst(), transportFlow.Src(), transportFlow.Dst()} - connectionID := wrapper.Storage.NewCustomRowID(connectionFlow.Hash(), startedAt) + connectionID := CustomRowID(connectionFlow.Hash(), startedAt) op := wrapper.Storage.Find(Connections).Context(wrapper.Context) err := op.Filter(OrderedDocument{{"_id", connectionID}}).First(&result) require.NoError(t, err) assert.NotNil(t, result) - assert.Equal(t, wrapper.Storage.NewCustomRowID(connectionFlow.Hash(), result.StartedAt), result.ID) + assert.Equal(t, CustomRowID(connectionFlow.Hash(), result.StartedAt), result.ID) assert.Equal(t, netFlow.Src().String(), result.SourceIP) assert.Equal(t, netFlow.Dst().String(), result.DestinationIP) assert.Equal(t, binary.BigEndian.Uint16(transportFlow.Src().Raw()), result.SourcePort) @@ -170,33 +170,33 @@ func TestConnectionFactory(t *testing.T) { wrapper.Destroy(t) } -type TestRuleManager struct { +type TestRulesManager struct { databaseUpdated chan RulesDatabase } -func (rm TestRuleManager) LoadRules() error { +func (rm TestRulesManager) LoadRules() error { return nil } -func (rm TestRuleManager) AddRule(_ context.Context, _ Rule) (RowID, error) { +func (rm TestRulesManager) AddRule(_ context.Context, _ Rule) (RowID, error) { return RowID{}, nil } -func (rm TestRuleManager) GetRule(_ RowID) (Rule, bool) { +func (rm TestRulesManager) GetRule(_ RowID) (Rule, bool) { return Rule{}, false } -func (rm TestRuleManager) UpdateRule(_ context.Context, _ Rule) bool { +func (rm TestRulesManager) UpdateRule(_ context.Context, _ RowID, _ Rule) bool { return false } -func (rm TestRuleManager) GetRules() []Rule { +func (rm TestRulesManager) GetRules() []Rule { return nil } -func (rm TestRuleManager) FillWithMatchedRules(_ *Connection, _ map[uint][]PatternSlice, _ map[uint][]PatternSlice) { +func (rm TestRulesManager) FillWithMatchedRules(_ *Connection, _ map[uint][]PatternSlice, _ map[uint][]PatternSlice) { } -func (rm TestRuleManager) DatabaseUpdateChannel() chan RulesDatabase { +func (rm TestRulesManager) DatabaseUpdateChannel() chan RulesDatabase { return rm.databaseUpdated } @@ -1,55 +1,101 @@ package main import ( - "fmt" "github.com/gin-gonic/gin" - log "github.com/sirupsen/logrus" - "github.com/go-playground/validator/v10" "net/http" ) -func ApplicationRoutes(engine *gin.Engine) { - engine.Static("/", "./frontend/build") +func ApplicationRoutes(engine *gin.Engine, rulesManager RulesManager) { + // engine.Static("/", "./frontend/build") api := engine.Group("/api") { - api.POST("/rule", func(c *gin.Context) { + api.GET("/rules", func(c *gin.Context) { + success(c, rulesManager.GetRules()) + }) + + api.POST("/rules", func(c *gin.Context) { var rule Rule - //data, _ := c.GetRawData() - // - //var json = jsoniter.ConfigCompatibleWithStandardLibrary - //err := json.Unmarshal(data, &filter) - // - //if err != nil { - // log.WithError(err).Error("failed to unmarshal") - // c.String(500, "failed to unmarshal") - //} - // - //err = validator.New().Struct(filter) - //log.WithError(err).WithField("filter", filter).Error("aaaa") - //c.String(200, "ok") + if err := c.ShouldBindJSON(&rule); err != nil { + badRequest(c, err) + return + } + + if id, err := rulesManager.AddRule(c, rule); err != nil { + unprocessableEntity(c, err) + } else { + success(c, UnorderedDocument{"id": id}) + } + }) + + api.GET("/rules/:id", func(c *gin.Context) { + hex := c.Param("id") + id, err := RowIDFromHex(hex) + if err != nil { + badRequest(c, err) + return + } + rule, found := rulesManager.GetRule(id) + if !found { + notFound(c, UnorderedDocument{"id": id}) + } else { + success(c, rule) + } + }) + api.PUT("/rules/:id", func(c *gin.Context) { + hex := c.Param("id") + id, err := RowIDFromHex(hex) + if err != nil { + badRequest(c, err) + return + } + var rule Rule if err := c.ShouldBindJSON(&rule); err != nil { - validationErrors, ok := err.(validator.ValidationErrors) - if !ok { - log.WithError(err).WithField("rule", rule).Error("oops") - c.JSON(http.StatusBadRequest, gin.H{}) - return - } - - for _, fieldErr := range validationErrors { - log.Println(fieldErr) - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("field '%v' does not respect the %v(%v) rule", - fieldErr.Field(), fieldErr.Tag(), fieldErr.Param()), - }) - log.WithError(err).WithField("rule", rule).Error("oops") - return // exit on first error - } + badRequest(c, err) + return } - c.JSON(200, rule) + updated := rulesManager.UpdateRule(c, id, rule) + if !updated { + notFound(c, UnorderedDocument{"id": id}) + } else { + success(c, rule) + } }) } } + +func success(c *gin.Context, obj interface{}) { + c.JSON(http.StatusOK, obj) +} + +func badRequest(c *gin.Context, err error) { + c.JSON(http.StatusBadRequest, UnorderedDocument{"result": "error", "error": err.Error()}) + + //validationErrors, ok := err.(validator.ValidationErrors) + //if !ok { + // log.WithError(err).WithField("rule", rule).Error("oops") + // c.JSON(http.StatusBadRequest, gin.H{}) + // return + //} + // + //for _, fieldErr := range validationErrors { + // log.Println(fieldErr) + // c.JSON(http.StatusBadRequest, gin.H{ + // "error": fmt.Sprintf("field '%v' does not respect the %v(%v) rule", + // fieldErr.Field(), fieldErr.Tag(), fieldErr.Param()), + // }) + // log.WithError(err).WithField("rule", rule).Error("oops") + // return // exit on first error + //} +} + +func unprocessableEntity(c *gin.Context, err error) { + c.JSON(http.StatusOK, UnorderedDocument{"result": "error", "error": err.Error()}) +} + +func notFound(c *gin.Context, obj interface{}) { + c.JSON(http.StatusNotFound, obj) +} diff --git a/rules_manager.go b/rules_manager.go index 388aeee..57f8768 100644 --- a/rules_manager.go +++ b/rules_manager.go @@ -2,7 +2,6 @@ package main import ( "context" - "crypto/sha256" "errors" "fmt" "github.com/flier/gohs/hyperscan" @@ -12,6 +11,10 @@ import ( "time" ) +const DirectionBoth = 0 +const DirectionToServer = 1 +const DirectionToClient = 2 + type RegexFlags struct { Caseless bool `json:"caseless" bson:"caseless,omitempty"` // Set case-insensitive matching. DotAll bool `json:"dot_all" bson:"dot_all,omitempty"` // Matching a `.` will not exclude newlines. @@ -26,13 +29,13 @@ type Pattern struct { Flags RegexFlags `json:"flags" bson:"flags,omitempty"` MinOccurrences uint `json:"min_occurrences" bson:"min_occurrences,omitempty"` MaxOccurrences uint `json:"max_occurrences" binding:"omitempty,gtefield=MinOccurrences" bson:"max_occurrences,omitempty"` + Direction uint8 `json:"direction" binding:"omitempty,max=2" bson:"direction,omitempty"` internalID int - compiledPattern *hyperscan.Pattern } type Filter struct { ServicePort uint16 `json:"service_port" bson:"service_port,omitempty"` - ClientAddress string `json:"client_address" binding:"ip_addr" bson:"client_address,omitempty"` + ClientAddress string `json:"client_address" binding:"omitempty,ip_addr" bson:"client_address,omitempty"` ClientPort uint16 `json:"client_port" bson:"client_port,omitempty"` MinDuration uint `json:"min_duration" bson:"min_duration,omitempty"` MaxDuration uint `json:"max_duration" binding:"omitempty,gtefield=MinDuration" bson:"max_duration,omitempty"` @@ -46,7 +49,7 @@ type Rule struct { Color string `json:"color" binding:"hexcolor" bson:"color"` Notes string `json:"notes" bson:"notes,omitempty"` Enabled bool `json:"enabled" bson:"enabled"` - Patterns []Pattern `json:"patterns" binding:"required,min=1" bson:"patterns"` + Patterns []Pattern `json:"patterns" bson:"patterns"` Filter Filter `json:"filter" bson:"filter,omitempty"` Version int64 `json:"version" bson:"version"` } @@ -61,7 +64,7 @@ type RulesManager interface { LoadRules() error AddRule(context context.Context, rule Rule) (RowID, error) GetRule(id RowID) (Rule, bool) - UpdateRule(context context.Context, rule Rule) bool + UpdateRule(context context.Context, id RowID, rule Rule) bool GetRules() []Rule FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice, serverMatches map[uint][]PatternSlice) DatabaseUpdateChannel() chan RulesDatabase @@ -71,7 +74,8 @@ type rulesManagerImpl struct { storage Storage rules map[RowID]Rule rulesByName map[string]Rule - patterns map[string]Pattern + patterns []*hyperscan.Pattern + patternsIds map[string]int mutex sync.Mutex databaseUpdated chan RulesDatabase validate *validator.Validate @@ -82,9 +86,10 @@ func NewRulesManager(storage Storage) RulesManager { storage: storage, rules: make(map[RowID]Rule), rulesByName: make(map[string]Rule), - patterns: make(map[string]Pattern), + patterns: make([]*hyperscan.Pattern, 0), + patternsIds: make(map[string]int), mutex: sync.Mutex{}, - databaseUpdated: make(chan RulesDatabase), + databaseUpdated: make(chan RulesDatabase, 1), validate: validator.New(), } } @@ -107,12 +112,12 @@ func (rm *rulesManagerImpl) LoadRules() error { func (rm *rulesManagerImpl) AddRule(context context.Context, rule Rule) (RowID, error) { rm.mutex.Lock() - rule.ID = rm.storage.NewCustomRowID(uint64(len(rm.rules)), time.Now()) + rule.ID = CustomRowID(uint64(len(rm.rules)), time.Now()) rule.Enabled = true if err := rm.validateAndAddRuleLocal(&rule); err != nil { rm.mutex.Unlock() - return rule.ID, err + return EmptyRowID(), err } if err := rm.generateDatabase(rule.ID); err != nil { @@ -133,8 +138,13 @@ func (rm *rulesManagerImpl) GetRule(id RowID) (Rule, bool) { return rule, isPresent } -func (rm *rulesManagerImpl) UpdateRule(context context.Context, rule Rule) bool { - updated, err := rm.storage.Update(Rules).Context(context).Filter(OrderedDocument{{"_id", rule.ID}}). +func (rm *rulesManagerImpl) UpdateRule(context context.Context, id RowID, rule Rule) bool { + newRule, isPresent := rm.rules[id] + if !isPresent { + return false + } + + updated, err := rm.storage.Update(Rules).Context(context).Filter(OrderedDocument{{"_id", id}}). One(UnorderedDocument{"name": rule.Name, "color": rule.Color}) if err != nil { log.WithError(err).WithField("rule", rule).Panic("failed to update rule on database") @@ -142,7 +152,6 @@ func (rm *rulesManagerImpl) UpdateRule(context context.Context, rule Rule) bool if updated { rm.mutex.Lock() - newRule := rm.rules[rule.ID] newRule.Name = rule.Name newRule.Color = rule.Color @@ -165,6 +174,19 @@ func (rm *rulesManagerImpl) GetRules() []Rule { return rules } +func (rm *rulesManagerImpl) SetFlag(context context.Context, flagRegex string) error { + _, err := rm.AddRule(context, Rule{ + Name: "flag", + Color: "#ff0000", + Notes: "Mark connections where the flag is stolen", + Patterns: []Pattern{ + {Regex: flagRegex, Direction: DirectionToClient}, + }, + }) + + return err +} + func (rm *rulesManagerImpl) FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice, serverMatches map[uint][]PatternSlice) { } @@ -178,27 +200,31 @@ func (rm *rulesManagerImpl) validateAndAddRuleLocal(rule *Rule) error { return errors.New("rule name must be unique") } - newPatterns := make(map[string]Pattern) + newPatterns := make([]*hyperscan.Pattern, 0, len(rule.Patterns)) for i, pattern := range rule.Patterns { if err := rm.validate.Struct(pattern); err != nil { return err } - hash := pattern.Hash() - if existingPattern, isPresent := rm.patterns[hash]; isPresent { - rule.Patterns[i] = existingPattern + compiledPattern, err := pattern.BuildPattern() + if err != nil { + return err + } + if existingPattern, isPresent := rm.patternsIds[compiledPattern.String()]; isPresent { + rule.Patterns[i].internalID = existingPattern continue } - if err := pattern.BuildPattern(); err != nil { - return err - } - pattern.internalID = len(rm.patterns) + len(newPatterns) - newPatterns[hash] = pattern + id := len(rm.patternsIds) + len(newPatterns) + rule.Patterns[i].internalID = id + compiledPattern.Id = id + newPatterns = append(newPatterns, compiledPattern) } - for key, value := range newPatterns { - rm.patterns[key] = value + startId := len(rm.patterns) + for id, pattern := range newPatterns { + rm.patterns = append(rm.patterns, pattern) + rm.patternsIds[pattern.String()] = startId + id } rm.rules[rule.ID] = *rule @@ -208,31 +234,24 @@ func (rm *rulesManagerImpl) validateAndAddRuleLocal(rule *Rule) error { } func (rm *rulesManagerImpl) generateDatabase(version RowID) error { - patterns := make([]*hyperscan.Pattern, 0, len(rm.patterns)) - for _, pattern := range rm.patterns { - patterns = append(patterns, pattern.compiledPattern) - } - database, err := hyperscan.NewStreamDatabase(patterns...) + database, err := hyperscan.NewStreamDatabase(rm.patterns...) if err != nil { return err } rm.databaseUpdated <- RulesDatabase{ database: database, - databaseSize: len(patterns), + databaseSize: len(rm.patterns), version: version, } + return nil } -func (p *Pattern) BuildPattern() error { - if p.compiledPattern != nil { - return nil - } - +func (p *Pattern) BuildPattern() (*hyperscan.Pattern, error) { hp, err := hyperscan.ParsePattern(fmt.Sprintf("/%s/", p.Regex)) if err != nil { - return err + return nil, err } if p.Flags.Caseless { @@ -255,16 +274,8 @@ func (p *Pattern) BuildPattern() error { } if !hp.IsValid() { - return errors.New("can't validate the pattern") + return nil, errors.New("can't validate the pattern") } - p.compiledPattern = hp - - return nil -} - -func (p Pattern) Hash() string { - hash := sha256.New() - hash.Write([]byte(fmt.Sprintf("%s|%v|%v|%v", p.Regex, p.Flags, p.MinOccurrences, p.MaxOccurrences))) - return fmt.Sprintf("%x", hash.Sum(nil)) + return hp, nil } diff --git a/rules_manager_test.go b/rules_manager_test.go new file mode 100644 index 0000000..53d085d --- /dev/null +++ b/rules_manager_test.go @@ -0,0 +1,77 @@ +package main + +import ( + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestAddRule(t *testing.T) { + wrapper := NewTestStorageWrapper(t) + wrapper.AddCollection(Rules) + + rulesManager := NewRulesManager(wrapper.Storage).(*rulesManagerImpl) + + checkVersion := func(id RowID) { + timeout := time.Tick(1 * time.Second) + + select { + case database := <-rulesManager.databaseUpdated: + assert.Equal(t, id, database.version) + case <-timeout: + t.Fatal("timeout") + } + } + + err := rulesManager.SetFlag(wrapper.Context, "FLAG{test}") + assert.NoError(t, err) + checkVersion(rulesManager.rulesByName["flag"].ID) + emptyID, err := rulesManager.AddRule(wrapper.Context, Rule{Name: "empty", Color: "#fff"}) + assert.NoError(t, err) + assert.NotNil(t, emptyID) + checkVersion(emptyID) + + rule1 := Rule{ + Name: "rule1", + Color: "#eee", + Patterns: []Pattern{ + {Regex: "nope", Flags: RegexFlags{Caseless: true}}, + }, + } + rule1ID, err := rulesManager.AddRule(wrapper.Context, rule1) + assert.NoError(t, err) + assert.NotNil(t, rule1ID) + checkVersion(rule1ID) + + rule2 := Rule{ + Name: "rule2", + Color: "#ddd", + Patterns: []Pattern{ + {Regex: "nope", Flags: RegexFlags{Caseless: true}}, + {Regex: "yep"}, + }, + } + rule2ID, err := rulesManager.AddRule(wrapper.Context, rule2) + assert.NoError(t, err) + assert.NotNil(t, rule2ID) + checkVersion(rule2ID) + + assert.Len(t, rulesManager.rules, 4) + assert.Len(t, rulesManager.rulesByName, 4) + assert.Len(t, rulesManager.patterns, 3) + assert.Len(t, rulesManager.patternsIds, 3) + assert.Equal(t, emptyID, rulesManager.rules[emptyID].ID) + assert.Equal(t, emptyID, rulesManager.rulesByName["empty"].ID) + assert.Len(t, rulesManager.rules[emptyID].Patterns, 0) + assert.Equal(t, rule1ID, rulesManager.rules[rule1ID].ID) + assert.Equal(t, rule1ID, rulesManager.rulesByName[rule1.Name].ID) + assert.Len(t, rulesManager.rules[rule1ID].Patterns, 1) + assert.Equal(t, 1, rulesManager.rules[rule1ID].Patterns[0].internalID) + assert.Equal(t, rule2ID, rulesManager.rules[rule2ID].ID) + assert.Equal(t, rule2ID, rulesManager.rulesByName[rule2.Name].ID) + assert.Len(t, rulesManager.rules[rule2ID].Patterns, 2) + assert.Equal(t, 1, rulesManager.rules[rule2ID].Patterns[0].internalID) + assert.Equal(t, 2, rulesManager.rules[rule2ID].Patterns[1].internalID) + + wrapper.Destroy(t) +} @@ -2,8 +2,6 @@ package main import ( "context" - "encoding/binary" - "encoding/hex" "errors" "fmt" log "github.com/sirupsen/logrus" @@ -11,7 +9,6 @@ import ( "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" - "time" ) // Collections names @@ -20,16 +17,12 @@ const ConnectionStreams = "connection_streams" const ImportedPcaps = "imported_pcaps" const Rules = "rules" -const defaultConnectionTimeout = 10 * time.Second - var ZeroRowID [12]byte type Storage interface { Insert(collectionName string) InsertOperation Update(collectionName string) UpdateOperation Find(collectionName string) FindOperation - NewCustomRowID(payload uint64, timestamp time.Time) RowID - NewRowID() RowID } type MongoStorage struct { @@ -67,23 +60,6 @@ func (storage *MongoStorage) Connect(ctx context.Context) error { return storage.client.Connect(ctx) } -func (storage *MongoStorage) NewCustomRowID(payload uint64, timestamp time.Time) RowID { - var key [12]byte - binary.BigEndian.PutUint32(key[0:4], uint32(timestamp.Unix())) - binary.BigEndian.PutUint64(key[4:12], payload) - - if oid, err := primitive.ObjectIDFromHex(hex.EncodeToString(key[:])); err == nil { - return oid - } else { - log.WithError(err).Warn("failed to create object id") - return primitive.NewObjectID() - } -} - -func (storage *MongoStorage) NewRowID() RowID { - return primitive.NewObjectID() -} - // InsertOne and InsertMany type InsertOperation interface { diff --git a/stream_handler.go b/stream_handler.go index 78326c6..a436fd5 100644 --- a/stream_handler.go +++ b/stream_handler.go @@ -147,7 +147,7 @@ func (sh *StreamHandler) onMatch(id uint, from uint64, to uint64, _ uint, _ inte func (sh *StreamHandler) storageCurrentDocument() { payload := sh.streamFlow.Hash()&uint64(0xffffffffffffff00) | uint64(len(sh.documentsIDs)) // LOL - streamID := sh.connection.Storage().NewCustomRowID(payload, sh.firstPacketSeen) + streamID := CustomRowID(payload, sh.firstPacketSeen) _, err := sh.connection.Storage().Insert(ConnectionStreams). One(ConnectionStream{ @@ -2,9 +2,13 @@ package main import ( "crypto/sha256" + "encoding/binary" + "encoding/hex" + log "github.com/sirupsen/logrus" + "go.mongodb.org/mongo-driver/bson/primitive" "io" - "log" "os" + "time" ) const invalidHashString = "invalid" @@ -28,3 +32,29 @@ func Sha256Sum(fileName string) (string, error) { return string(h.Sum(nil)), nil } + +func CustomRowID(payload uint64, timestamp time.Time) RowID { + var key [12]byte + binary.BigEndian.PutUint32(key[0:4], uint32(timestamp.Unix())) + binary.BigEndian.PutUint64(key[4:12], payload) + + if oid, err := primitive.ObjectIDFromHex(hex.EncodeToString(key[:])); err == nil { + return oid + } else { + log.WithError(err).Warn("failed to create object id") + return primitive.NewObjectID() + } +} + +func NewRowID() RowID { + return primitive.NewObjectID() +} + +func EmptyRowID() RowID { + return [12]byte{} +} + +func RowIDFromHex(hex string) (RowID, error) { + rowID, err := primitive.ObjectIDFromHex(hex) + return rowID, err +} |