aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEmiliano Ciavatta2020-04-13 15:12:35 +0000
committerEmiliano Ciavatta2020-04-13 15:12:35 +0000
commita56a4e391d541ae05de0203f3d493edc3b04681d (patch)
treeab9344a650305aafb5afe552dc8cad63684de643
parent7113463dead05631339fdab94de9440201c42489 (diff)
Add AddRoute tests
-rw-r--r--caronte.go3
-rw-r--r--connection_handler.go2
-rw-r--r--connection_handler_test.go32
-rw-r--r--routes.go118
-rw-r--r--rules_manager.go103
-rw-r--r--rules_manager_test.go77
-rw-r--r--storage.go24
-rw-r--r--stream_handler.go2
-rw-r--r--utils.go32
9 files changed, 267 insertions, 126 deletions
diff --git a/caronte.go b/caronte.go
index a6fa584..c1a8a29 100644
--- a/caronte.go
+++ b/caronte.go
@@ -24,8 +24,9 @@ func main() {
log.WithError(err).Fatal("failed to connect to MongoDB")
}
+ rulesManager := NewRulesManager(storage)
router := gin.Default()
- ApplicationRoutes(router)
+ ApplicationRoutes(router, rulesManager)
err = router.Run(fmt.Sprintf("%s:%v", *bindAddress, *bindPort))
if err != nil {
log.WithError(err).Fatal("failed to create the server")
diff --git a/connection_handler.go b/connection_handler.go
index b8bddc9..e4730cc 100644
--- a/connection_handler.go
+++ b/connection_handler.go
@@ -188,7 +188,7 @@ func (ch *connectionHandlerImpl) Complete(handler *StreamHandler) {
server = handler
}
- connectionID := ch.Storage().NewCustomRowID(ch.connectionFlow.Hash(), startedAt)
+ connectionID := CustomRowID(ch.connectionFlow.Hash(), startedAt)
connection := Connection{
ID: connectionID,
SourceIP: ch.connectionFlow[0].String(),
diff --git a/connection_handler_test.go b/connection_handler_test.go
index 20eb041..dd097e4 100644
--- a/connection_handler_test.go
+++ b/connection_handler_test.go
@@ -18,7 +18,7 @@ import (
func TestTakeReleaseScanners(t *testing.T) {
wrapper := NewTestStorageWrapper(t)
serverIP := layers.NewIPEndpoint(net.ParseIP(testDstIP))
- ruleManager := TestRuleManager{
+ ruleManager := TestRulesManager{
databaseUpdated: make(chan RulesDatabase),
}
@@ -26,7 +26,7 @@ func TestTakeReleaseScanners(t *testing.T) {
require.NoError(t, err)
factory := NewBiDirectionalStreamFactory(wrapper.Storage, serverIP, &ruleManager)
- version := wrapper.Storage.NewRowID()
+ version := NewRowID()
ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version}
time.Sleep(10 * time.Millisecond)
@@ -36,7 +36,7 @@ func TestTakeReleaseScanners(t *testing.T) {
assert.Equal(t, scanner.version, version)
if i%50 == 0 {
- version = wrapper.Storage.NewRowID()
+ version = NewRowID()
ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version}
time.Sleep(10 * time.Millisecond)
}
@@ -54,7 +54,7 @@ func TestTakeReleaseScanners(t *testing.T) {
}
assert.Len(t, factory.scanners, n)
- version = wrapper.Storage.NewRowID()
+ version = NewRowID()
ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version}
time.Sleep(10 * time.Millisecond)
@@ -73,7 +73,7 @@ func TestConnectionFactory(t *testing.T) {
wrapper.AddCollection(Connections)
wrapper.AddCollection(ConnectionStreams)
- ruleManager := TestRuleManager{
+ ruleManager := TestRulesManager{
databaseUpdated: make(chan RulesDatabase),
}
@@ -89,7 +89,7 @@ func TestConnectionFactory(t *testing.T) {
require.NoError(t, err)
factory := NewBiDirectionalStreamFactory(wrapper.Storage, serverIP, &ruleManager)
- version := wrapper.Storage.NewRowID()
+ version := NewRowID()
ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version}
time.Sleep(10 * time.Millisecond)
@@ -122,13 +122,13 @@ func TestConnectionFactory(t *testing.T) {
var result Connection
connectionFlow := StreamFlow{netFlow.Src(), netFlow.Dst(), transportFlow.Src(), transportFlow.Dst()}
- connectionID := wrapper.Storage.NewCustomRowID(connectionFlow.Hash(), startedAt)
+ connectionID := CustomRowID(connectionFlow.Hash(), startedAt)
op := wrapper.Storage.Find(Connections).Context(wrapper.Context)
err := op.Filter(OrderedDocument{{"_id", connectionID}}).First(&result)
require.NoError(t, err)
assert.NotNil(t, result)
- assert.Equal(t, wrapper.Storage.NewCustomRowID(connectionFlow.Hash(), result.StartedAt), result.ID)
+ assert.Equal(t, CustomRowID(connectionFlow.Hash(), result.StartedAt), result.ID)
assert.Equal(t, netFlow.Src().String(), result.SourceIP)
assert.Equal(t, netFlow.Dst().String(), result.DestinationIP)
assert.Equal(t, binary.BigEndian.Uint16(transportFlow.Src().Raw()), result.SourcePort)
@@ -170,33 +170,33 @@ func TestConnectionFactory(t *testing.T) {
wrapper.Destroy(t)
}
-type TestRuleManager struct {
+type TestRulesManager struct {
databaseUpdated chan RulesDatabase
}
-func (rm TestRuleManager) LoadRules() error {
+func (rm TestRulesManager) LoadRules() error {
return nil
}
-func (rm TestRuleManager) AddRule(_ context.Context, _ Rule) (RowID, error) {
+func (rm TestRulesManager) AddRule(_ context.Context, _ Rule) (RowID, error) {
return RowID{}, nil
}
-func (rm TestRuleManager) GetRule(_ RowID) (Rule, bool) {
+func (rm TestRulesManager) GetRule(_ RowID) (Rule, bool) {
return Rule{}, false
}
-func (rm TestRuleManager) UpdateRule(_ context.Context, _ Rule) bool {
+func (rm TestRulesManager) UpdateRule(_ context.Context, _ RowID, _ Rule) bool {
return false
}
-func (rm TestRuleManager) GetRules() []Rule {
+func (rm TestRulesManager) GetRules() []Rule {
return nil
}
-func (rm TestRuleManager) FillWithMatchedRules(_ *Connection, _ map[uint][]PatternSlice, _ map[uint][]PatternSlice) {
+func (rm TestRulesManager) FillWithMatchedRules(_ *Connection, _ map[uint][]PatternSlice, _ map[uint][]PatternSlice) {
}
-func (rm TestRuleManager) DatabaseUpdateChannel() chan RulesDatabase {
+func (rm TestRulesManager) DatabaseUpdateChannel() chan RulesDatabase {
return rm.databaseUpdated
}
diff --git a/routes.go b/routes.go
index 37088e2..111dd78 100644
--- a/routes.go
+++ b/routes.go
@@ -1,55 +1,101 @@
package main
import (
- "fmt"
"github.com/gin-gonic/gin"
- log "github.com/sirupsen/logrus"
- "github.com/go-playground/validator/v10"
"net/http"
)
-func ApplicationRoutes(engine *gin.Engine) {
- engine.Static("/", "./frontend/build")
+func ApplicationRoutes(engine *gin.Engine, rulesManager RulesManager) {
+ // engine.Static("/", "./frontend/build")
api := engine.Group("/api")
{
- api.POST("/rule", func(c *gin.Context) {
+ api.GET("/rules", func(c *gin.Context) {
+ success(c, rulesManager.GetRules())
+ })
+
+ api.POST("/rules", func(c *gin.Context) {
var rule Rule
- //data, _ := c.GetRawData()
- //
- //var json = jsoniter.ConfigCompatibleWithStandardLibrary
- //err := json.Unmarshal(data, &filter)
- //
- //if err != nil {
- // log.WithError(err).Error("failed to unmarshal")
- // c.String(500, "failed to unmarshal")
- //}
- //
- //err = validator.New().Struct(filter)
- //log.WithError(err).WithField("filter", filter).Error("aaaa")
- //c.String(200, "ok")
+ if err := c.ShouldBindJSON(&rule); err != nil {
+ badRequest(c, err)
+ return
+ }
+
+ if id, err := rulesManager.AddRule(c, rule); err != nil {
+ unprocessableEntity(c, err)
+ } else {
+ success(c, UnorderedDocument{"id": id})
+ }
+ })
+
+ api.GET("/rules/:id", func(c *gin.Context) {
+ hex := c.Param("id")
+ id, err := RowIDFromHex(hex)
+ if err != nil {
+ badRequest(c, err)
+ return
+ }
+ rule, found := rulesManager.GetRule(id)
+ if !found {
+ notFound(c, UnorderedDocument{"id": id})
+ } else {
+ success(c, rule)
+ }
+ })
+ api.PUT("/rules/:id", func(c *gin.Context) {
+ hex := c.Param("id")
+ id, err := RowIDFromHex(hex)
+ if err != nil {
+ badRequest(c, err)
+ return
+ }
+ var rule Rule
if err := c.ShouldBindJSON(&rule); err != nil {
- validationErrors, ok := err.(validator.ValidationErrors)
- if !ok {
- log.WithError(err).WithField("rule", rule).Error("oops")
- c.JSON(http.StatusBadRequest, gin.H{})
- return
- }
-
- for _, fieldErr := range validationErrors {
- log.Println(fieldErr)
- c.JSON(http.StatusBadRequest, gin.H{
- "error": fmt.Sprintf("field '%v' does not respect the %v(%v) rule",
- fieldErr.Field(), fieldErr.Tag(), fieldErr.Param()),
- })
- log.WithError(err).WithField("rule", rule).Error("oops")
- return // exit on first error
- }
+ badRequest(c, err)
+ return
}
- c.JSON(200, rule)
+ updated := rulesManager.UpdateRule(c, id, rule)
+ if !updated {
+ notFound(c, UnorderedDocument{"id": id})
+ } else {
+ success(c, rule)
+ }
})
}
}
+
+func success(c *gin.Context, obj interface{}) {
+ c.JSON(http.StatusOK, obj)
+}
+
+func badRequest(c *gin.Context, err error) {
+ c.JSON(http.StatusBadRequest, UnorderedDocument{"result": "error", "error": err.Error()})
+
+ //validationErrors, ok := err.(validator.ValidationErrors)
+ //if !ok {
+ // log.WithError(err).WithField("rule", rule).Error("oops")
+ // c.JSON(http.StatusBadRequest, gin.H{})
+ // return
+ //}
+ //
+ //for _, fieldErr := range validationErrors {
+ // log.Println(fieldErr)
+ // c.JSON(http.StatusBadRequest, gin.H{
+ // "error": fmt.Sprintf("field '%v' does not respect the %v(%v) rule",
+ // fieldErr.Field(), fieldErr.Tag(), fieldErr.Param()),
+ // })
+ // log.WithError(err).WithField("rule", rule).Error("oops")
+ // return // exit on first error
+ //}
+}
+
+func unprocessableEntity(c *gin.Context, err error) {
+ c.JSON(http.StatusOK, UnorderedDocument{"result": "error", "error": err.Error()})
+}
+
+func notFound(c *gin.Context, obj interface{}) {
+ c.JSON(http.StatusNotFound, obj)
+}
diff --git a/rules_manager.go b/rules_manager.go
index 388aeee..57f8768 100644
--- a/rules_manager.go
+++ b/rules_manager.go
@@ -2,7 +2,6 @@ package main
import (
"context"
- "crypto/sha256"
"errors"
"fmt"
"github.com/flier/gohs/hyperscan"
@@ -12,6 +11,10 @@ import (
"time"
)
+const DirectionBoth = 0
+const DirectionToServer = 1
+const DirectionToClient = 2
+
type RegexFlags struct {
Caseless bool `json:"caseless" bson:"caseless,omitempty"` // Set case-insensitive matching.
DotAll bool `json:"dot_all" bson:"dot_all,omitempty"` // Matching a `.` will not exclude newlines.
@@ -26,13 +29,13 @@ type Pattern struct {
Flags RegexFlags `json:"flags" bson:"flags,omitempty"`
MinOccurrences uint `json:"min_occurrences" bson:"min_occurrences,omitempty"`
MaxOccurrences uint `json:"max_occurrences" binding:"omitempty,gtefield=MinOccurrences" bson:"max_occurrences,omitempty"`
+ Direction uint8 `json:"direction" binding:"omitempty,max=2" bson:"direction,omitempty"`
internalID int
- compiledPattern *hyperscan.Pattern
}
type Filter struct {
ServicePort uint16 `json:"service_port" bson:"service_port,omitempty"`
- ClientAddress string `json:"client_address" binding:"ip_addr" bson:"client_address,omitempty"`
+ ClientAddress string `json:"client_address" binding:"omitempty,ip_addr" bson:"client_address,omitempty"`
ClientPort uint16 `json:"client_port" bson:"client_port,omitempty"`
MinDuration uint `json:"min_duration" bson:"min_duration,omitempty"`
MaxDuration uint `json:"max_duration" binding:"omitempty,gtefield=MinDuration" bson:"max_duration,omitempty"`
@@ -46,7 +49,7 @@ type Rule struct {
Color string `json:"color" binding:"hexcolor" bson:"color"`
Notes string `json:"notes" bson:"notes,omitempty"`
Enabled bool `json:"enabled" bson:"enabled"`
- Patterns []Pattern `json:"patterns" binding:"required,min=1" bson:"patterns"`
+ Patterns []Pattern `json:"patterns" bson:"patterns"`
Filter Filter `json:"filter" bson:"filter,omitempty"`
Version int64 `json:"version" bson:"version"`
}
@@ -61,7 +64,7 @@ type RulesManager interface {
LoadRules() error
AddRule(context context.Context, rule Rule) (RowID, error)
GetRule(id RowID) (Rule, bool)
- UpdateRule(context context.Context, rule Rule) bool
+ UpdateRule(context context.Context, id RowID, rule Rule) bool
GetRules() []Rule
FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice, serverMatches map[uint][]PatternSlice)
DatabaseUpdateChannel() chan RulesDatabase
@@ -71,7 +74,8 @@ type rulesManagerImpl struct {
storage Storage
rules map[RowID]Rule
rulesByName map[string]Rule
- patterns map[string]Pattern
+ patterns []*hyperscan.Pattern
+ patternsIds map[string]int
mutex sync.Mutex
databaseUpdated chan RulesDatabase
validate *validator.Validate
@@ -82,9 +86,10 @@ func NewRulesManager(storage Storage) RulesManager {
storage: storage,
rules: make(map[RowID]Rule),
rulesByName: make(map[string]Rule),
- patterns: make(map[string]Pattern),
+ patterns: make([]*hyperscan.Pattern, 0),
+ patternsIds: make(map[string]int),
mutex: sync.Mutex{},
- databaseUpdated: make(chan RulesDatabase),
+ databaseUpdated: make(chan RulesDatabase, 1),
validate: validator.New(),
}
}
@@ -107,12 +112,12 @@ func (rm *rulesManagerImpl) LoadRules() error {
func (rm *rulesManagerImpl) AddRule(context context.Context, rule Rule) (RowID, error) {
rm.mutex.Lock()
- rule.ID = rm.storage.NewCustomRowID(uint64(len(rm.rules)), time.Now())
+ rule.ID = CustomRowID(uint64(len(rm.rules)), time.Now())
rule.Enabled = true
if err := rm.validateAndAddRuleLocal(&rule); err != nil {
rm.mutex.Unlock()
- return rule.ID, err
+ return EmptyRowID(), err
}
if err := rm.generateDatabase(rule.ID); err != nil {
@@ -133,8 +138,13 @@ func (rm *rulesManagerImpl) GetRule(id RowID) (Rule, bool) {
return rule, isPresent
}
-func (rm *rulesManagerImpl) UpdateRule(context context.Context, rule Rule) bool {
- updated, err := rm.storage.Update(Rules).Context(context).Filter(OrderedDocument{{"_id", rule.ID}}).
+func (rm *rulesManagerImpl) UpdateRule(context context.Context, id RowID, rule Rule) bool {
+ newRule, isPresent := rm.rules[id]
+ if !isPresent {
+ return false
+ }
+
+ updated, err := rm.storage.Update(Rules).Context(context).Filter(OrderedDocument{{"_id", id}}).
One(UnorderedDocument{"name": rule.Name, "color": rule.Color})
if err != nil {
log.WithError(err).WithField("rule", rule).Panic("failed to update rule on database")
@@ -142,7 +152,6 @@ func (rm *rulesManagerImpl) UpdateRule(context context.Context, rule Rule) bool
if updated {
rm.mutex.Lock()
- newRule := rm.rules[rule.ID]
newRule.Name = rule.Name
newRule.Color = rule.Color
@@ -165,6 +174,19 @@ func (rm *rulesManagerImpl) GetRules() []Rule {
return rules
}
+func (rm *rulesManagerImpl) SetFlag(context context.Context, flagRegex string) error {
+ _, err := rm.AddRule(context, Rule{
+ Name: "flag",
+ Color: "#ff0000",
+ Notes: "Mark connections where the flag is stolen",
+ Patterns: []Pattern{
+ {Regex: flagRegex, Direction: DirectionToClient},
+ },
+ })
+
+ return err
+}
+
func (rm *rulesManagerImpl) FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice,
serverMatches map[uint][]PatternSlice) {
}
@@ -178,27 +200,31 @@ func (rm *rulesManagerImpl) validateAndAddRuleLocal(rule *Rule) error {
return errors.New("rule name must be unique")
}
- newPatterns := make(map[string]Pattern)
+ newPatterns := make([]*hyperscan.Pattern, 0, len(rule.Patterns))
for i, pattern := range rule.Patterns {
if err := rm.validate.Struct(pattern); err != nil {
return err
}
- hash := pattern.Hash()
- if existingPattern, isPresent := rm.patterns[hash]; isPresent {
- rule.Patterns[i] = existingPattern
+ compiledPattern, err := pattern.BuildPattern()
+ if err != nil {
+ return err
+ }
+ if existingPattern, isPresent := rm.patternsIds[compiledPattern.String()]; isPresent {
+ rule.Patterns[i].internalID = existingPattern
continue
}
- if err := pattern.BuildPattern(); err != nil {
- return err
- }
- pattern.internalID = len(rm.patterns) + len(newPatterns)
- newPatterns[hash] = pattern
+ id := len(rm.patternsIds) + len(newPatterns)
+ rule.Patterns[i].internalID = id
+ compiledPattern.Id = id
+ newPatterns = append(newPatterns, compiledPattern)
}
- for key, value := range newPatterns {
- rm.patterns[key] = value
+ startId := len(rm.patterns)
+ for id, pattern := range newPatterns {
+ rm.patterns = append(rm.patterns, pattern)
+ rm.patternsIds[pattern.String()] = startId + id
}
rm.rules[rule.ID] = *rule
@@ -208,31 +234,24 @@ func (rm *rulesManagerImpl) validateAndAddRuleLocal(rule *Rule) error {
}
func (rm *rulesManagerImpl) generateDatabase(version RowID) error {
- patterns := make([]*hyperscan.Pattern, 0, len(rm.patterns))
- for _, pattern := range rm.patterns {
- patterns = append(patterns, pattern.compiledPattern)
- }
- database, err := hyperscan.NewStreamDatabase(patterns...)
+ database, err := hyperscan.NewStreamDatabase(rm.patterns...)
if err != nil {
return err
}
rm.databaseUpdated <- RulesDatabase{
database: database,
- databaseSize: len(patterns),
+ databaseSize: len(rm.patterns),
version: version,
}
+
return nil
}
-func (p *Pattern) BuildPattern() error {
- if p.compiledPattern != nil {
- return nil
- }
-
+func (p *Pattern) BuildPattern() (*hyperscan.Pattern, error) {
hp, err := hyperscan.ParsePattern(fmt.Sprintf("/%s/", p.Regex))
if err != nil {
- return err
+ return nil, err
}
if p.Flags.Caseless {
@@ -255,16 +274,8 @@ func (p *Pattern) BuildPattern() error {
}
if !hp.IsValid() {
- return errors.New("can't validate the pattern")
+ return nil, errors.New("can't validate the pattern")
}
- p.compiledPattern = hp
-
- return nil
-}
-
-func (p Pattern) Hash() string {
- hash := sha256.New()
- hash.Write([]byte(fmt.Sprintf("%s|%v|%v|%v", p.Regex, p.Flags, p.MinOccurrences, p.MaxOccurrences)))
- return fmt.Sprintf("%x", hash.Sum(nil))
+ return hp, nil
}
diff --git a/rules_manager_test.go b/rules_manager_test.go
new file mode 100644
index 0000000..53d085d
--- /dev/null
+++ b/rules_manager_test.go
@@ -0,0 +1,77 @@
+package main
+
+import (
+ "github.com/stretchr/testify/assert"
+ "testing"
+ "time"
+)
+
+func TestAddRule(t *testing.T) {
+ wrapper := NewTestStorageWrapper(t)
+ wrapper.AddCollection(Rules)
+
+ rulesManager := NewRulesManager(wrapper.Storage).(*rulesManagerImpl)
+
+ checkVersion := func(id RowID) {
+ timeout := time.Tick(1 * time.Second)
+
+ select {
+ case database := <-rulesManager.databaseUpdated:
+ assert.Equal(t, id, database.version)
+ case <-timeout:
+ t.Fatal("timeout")
+ }
+ }
+
+ err := rulesManager.SetFlag(wrapper.Context, "FLAG{test}")
+ assert.NoError(t, err)
+ checkVersion(rulesManager.rulesByName["flag"].ID)
+ emptyID, err := rulesManager.AddRule(wrapper.Context, Rule{Name: "empty", Color: "#fff"})
+ assert.NoError(t, err)
+ assert.NotNil(t, emptyID)
+ checkVersion(emptyID)
+
+ rule1 := Rule{
+ Name: "rule1",
+ Color: "#eee",
+ Patterns: []Pattern{
+ {Regex: "nope", Flags: RegexFlags{Caseless: true}},
+ },
+ }
+ rule1ID, err := rulesManager.AddRule(wrapper.Context, rule1)
+ assert.NoError(t, err)
+ assert.NotNil(t, rule1ID)
+ checkVersion(rule1ID)
+
+ rule2 := Rule{
+ Name: "rule2",
+ Color: "#ddd",
+ Patterns: []Pattern{
+ {Regex: "nope", Flags: RegexFlags{Caseless: true}},
+ {Regex: "yep"},
+ },
+ }
+ rule2ID, err := rulesManager.AddRule(wrapper.Context, rule2)
+ assert.NoError(t, err)
+ assert.NotNil(t, rule2ID)
+ checkVersion(rule2ID)
+
+ assert.Len(t, rulesManager.rules, 4)
+ assert.Len(t, rulesManager.rulesByName, 4)
+ assert.Len(t, rulesManager.patterns, 3)
+ assert.Len(t, rulesManager.patternsIds, 3)
+ assert.Equal(t, emptyID, rulesManager.rules[emptyID].ID)
+ assert.Equal(t, emptyID, rulesManager.rulesByName["empty"].ID)
+ assert.Len(t, rulesManager.rules[emptyID].Patterns, 0)
+ assert.Equal(t, rule1ID, rulesManager.rules[rule1ID].ID)
+ assert.Equal(t, rule1ID, rulesManager.rulesByName[rule1.Name].ID)
+ assert.Len(t, rulesManager.rules[rule1ID].Patterns, 1)
+ assert.Equal(t, 1, rulesManager.rules[rule1ID].Patterns[0].internalID)
+ assert.Equal(t, rule2ID, rulesManager.rules[rule2ID].ID)
+ assert.Equal(t, rule2ID, rulesManager.rulesByName[rule2.Name].ID)
+ assert.Len(t, rulesManager.rules[rule2ID].Patterns, 2)
+ assert.Equal(t, 1, rulesManager.rules[rule2ID].Patterns[0].internalID)
+ assert.Equal(t, 2, rulesManager.rules[rule2ID].Patterns[1].internalID)
+
+ wrapper.Destroy(t)
+}
diff --git a/storage.go b/storage.go
index 7d98ba0..5ee9f3e 100644
--- a/storage.go
+++ b/storage.go
@@ -2,8 +2,6 @@ package main
import (
"context"
- "encoding/binary"
- "encoding/hex"
"errors"
"fmt"
log "github.com/sirupsen/logrus"
@@ -11,7 +9,6 @@ import (
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
- "time"
)
// Collections names
@@ -20,16 +17,12 @@ const ConnectionStreams = "connection_streams"
const ImportedPcaps = "imported_pcaps"
const Rules = "rules"
-const defaultConnectionTimeout = 10 * time.Second
-
var ZeroRowID [12]byte
type Storage interface {
Insert(collectionName string) InsertOperation
Update(collectionName string) UpdateOperation
Find(collectionName string) FindOperation
- NewCustomRowID(payload uint64, timestamp time.Time) RowID
- NewRowID() RowID
}
type MongoStorage struct {
@@ -67,23 +60,6 @@ func (storage *MongoStorage) Connect(ctx context.Context) error {
return storage.client.Connect(ctx)
}
-func (storage *MongoStorage) NewCustomRowID(payload uint64, timestamp time.Time) RowID {
- var key [12]byte
- binary.BigEndian.PutUint32(key[0:4], uint32(timestamp.Unix()))
- binary.BigEndian.PutUint64(key[4:12], payload)
-
- if oid, err := primitive.ObjectIDFromHex(hex.EncodeToString(key[:])); err == nil {
- return oid
- } else {
- log.WithError(err).Warn("failed to create object id")
- return primitive.NewObjectID()
- }
-}
-
-func (storage *MongoStorage) NewRowID() RowID {
- return primitive.NewObjectID()
-}
-
// InsertOne and InsertMany
type InsertOperation interface {
diff --git a/stream_handler.go b/stream_handler.go
index 78326c6..a436fd5 100644
--- a/stream_handler.go
+++ b/stream_handler.go
@@ -147,7 +147,7 @@ func (sh *StreamHandler) onMatch(id uint, from uint64, to uint64, _ uint, _ inte
func (sh *StreamHandler) storageCurrentDocument() {
payload := sh.streamFlow.Hash()&uint64(0xffffffffffffff00) | uint64(len(sh.documentsIDs)) // LOL
- streamID := sh.connection.Storage().NewCustomRowID(payload, sh.firstPacketSeen)
+ streamID := CustomRowID(payload, sh.firstPacketSeen)
_, err := sh.connection.Storage().Insert(ConnectionStreams).
One(ConnectionStream{
diff --git a/utils.go b/utils.go
index cc99d93..cb60ea6 100644
--- a/utils.go
+++ b/utils.go
@@ -2,9 +2,13 @@ package main
import (
"crypto/sha256"
+ "encoding/binary"
+ "encoding/hex"
+ log "github.com/sirupsen/logrus"
+ "go.mongodb.org/mongo-driver/bson/primitive"
"io"
- "log"
"os"
+ "time"
)
const invalidHashString = "invalid"
@@ -28,3 +32,29 @@ func Sha256Sum(fileName string) (string, error) {
return string(h.Sum(nil)), nil
}
+
+func CustomRowID(payload uint64, timestamp time.Time) RowID {
+ var key [12]byte
+ binary.BigEndian.PutUint32(key[0:4], uint32(timestamp.Unix()))
+ binary.BigEndian.PutUint64(key[4:12], payload)
+
+ if oid, err := primitive.ObjectIDFromHex(hex.EncodeToString(key[:])); err == nil {
+ return oid
+ } else {
+ log.WithError(err).Warn("failed to create object id")
+ return primitive.NewObjectID()
+ }
+}
+
+func NewRowID() RowID {
+ return primitive.NewObjectID()
+}
+
+func EmptyRowID() RowID {
+ return [12]byte{}
+}
+
+func RowIDFromHex(hex string) (RowID, error) {
+ rowID, err := primitive.ObjectIDFromHex(hex)
+ return rowID, err
+}