From 7113463dead05631339fdab94de9440201c42489 Mon Sep 17 00:00:00 2001 From: Emiliano Ciavatta Date: Mon, 13 Apr 2020 12:29:00 +0200 Subject: Update rules --- connection_handler_test.go | 46 +++++++----- routes.go | 29 ++++++-- rules_manager.go | 175 +++++++++++++++++++++++++++------------------ 3 files changed, 161 insertions(+), 89 deletions(-) diff --git a/connection_handler_test.go b/connection_handler_test.go index aefa058..20eb041 100644 --- a/connection_handler_test.go +++ b/connection_handler_test.go @@ -96,7 +96,7 @@ func TestConnectionFactory(t *testing.T) { testInteraction := func(netFlow gopacket.Flow, transportFlow gopacket.Flow, otherSeenChan chan time.Time, completed chan bool) { - time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond) + time.Sleep(time.Duration(rand.Intn(1000)) * time.Millisecond) stream := factory.New(netFlow, transportFlow) seen := time.Now() stream.Reassembled([]tcpassembly.Reassembly{{[]byte{}, 0, true, true, seen}}) @@ -106,25 +106,25 @@ func TestConnectionFactory(t *testing.T) { if netFlow == serverClientNetFlow { otherSeenChan <- seen return + } + + otherSeen, ok := <-otherSeenChan + require.True(t, ok) + + if seen.Before(otherSeen) { + startedAt = seen + closedAt = otherSeen } else { - otherSeen, ok := <-otherSeenChan - require.True(t, ok) - - if seen.Before(otherSeen) { - startedAt = seen - closedAt = otherSeen - } else { - startedAt = otherSeen - closedAt = seen - } + startedAt = otherSeen + closedAt = seen } close(otherSeenChan) var result Connection connectionFlow := StreamFlow{netFlow.Src(), netFlow.Dst(), transportFlow.Src(), transportFlow.Dst()} connectionID := wrapper.Storage.NewCustomRowID(connectionFlow.Hash(), startedAt) - err = wrapper.Storage.Find(Connections).Context(wrapper.Context). - Filter(OrderedDocument{{"_id", connectionID}}).First(&result) + op := wrapper.Storage.Find(Connections).Context(wrapper.Context) + err := op.Filter(OrderedDocument{{"_id", connectionID}}).First(&result) require.NoError(t, err) assert.NotNil(t, result) @@ -140,7 +140,7 @@ func TestConnectionFactory(t *testing.T) { } completed := make(chan bool) - n := 3000 + n := 1000 for port := 40000; port < 40000+n; port++ { clientPort := layers.NewTCPPortEndpoint(layers.TCPPort(port)) @@ -154,7 +154,7 @@ func TestConnectionFactory(t *testing.T) { go testInteraction(serverClientNetFlow, serverClientTransportFlow, otherSeenChan, completed) } - timeout := time.Tick(1 * time.Second) + timeout := time.Tick(10 * time.Second) for i := 0; i < n; i++ { select { case <- completed: @@ -178,8 +178,20 @@ func (rm TestRuleManager) LoadRules() error { return nil } -func (rm TestRuleManager) AddRule(_ context.Context, _ Rule) (string, error) { - return "", nil +func (rm TestRuleManager) AddRule(_ context.Context, _ Rule) (RowID, error) { + return RowID{}, nil +} + +func (rm TestRuleManager) GetRule(_ RowID) (Rule, bool) { + return Rule{}, false +} + +func (rm TestRuleManager) UpdateRule(_ context.Context, _ Rule) bool { + return false +} + +func (rm TestRuleManager) GetRules() []Rule { + return nil } func (rm TestRuleManager) FillWithMatchedRules(_ *Connection, _ map[uint][]PatternSlice, _ map[uint][]PatternSlice) { diff --git a/routes.go b/routes.go index 3759382..37088e2 100644 --- a/routes.go +++ b/routes.go @@ -3,8 +3,8 @@ package main import ( "fmt" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" log "github.com/sirupsen/logrus" + "github.com/go-playground/validator/v10" "net/http" ) @@ -13,17 +13,38 @@ func ApplicationRoutes(engine *gin.Engine) { api := engine.Group("/api") { - api.POST("/rules", func(c *gin.Context) { + api.POST("/rule", func(c *gin.Context) { var rule Rule + //data, _ := c.GetRawData() + // + //var json = jsoniter.ConfigCompatibleWithStandardLibrary + //err := json.Unmarshal(data, &filter) + // + //if err != nil { + // log.WithError(err).Error("failed to unmarshal") + // c.String(500, "failed to unmarshal") + //} + // + //err = validator.New().Struct(filter) + //log.WithError(err).WithField("filter", filter).Error("aaaa") + //c.String(200, "ok") + if err := c.ShouldBindJSON(&rule); err != nil { - for _, fieldErr := range err.(validator.ValidationErrors) { + validationErrors, ok := err.(validator.ValidationErrors) + if !ok { + log.WithError(err).WithField("rule", rule).Error("oops") + c.JSON(http.StatusBadRequest, gin.H{}) + return + } + + for _, fieldErr := range validationErrors { log.Println(fieldErr) c.JSON(http.StatusBadRequest, gin.H{ "error": fmt.Sprintf("field '%v' does not respect the %v(%v) rule", fieldErr.Field(), fieldErr.Tag(), fieldErr.Param()), }) - log.WithError(err).WithField("rule", rule).Panic("oops") + log.WithError(err).WithField("rule", rule).Error("oops") return // exit on first error } } diff --git a/rules_manager.go b/rules_manager.go index 0750d53..388aeee 100644 --- a/rules_manager.go +++ b/rules_manager.go @@ -6,45 +6,44 @@ import ( "errors" "fmt" "github.com/flier/gohs/hyperscan" + "github.com/go-playground/validator/v10" log "github.com/sirupsen/logrus" "sync" "time" ) type RegexFlags struct { - Caseless bool `json:"caseless"` // Set case-insensitive matching. - DotAll bool `json:"dot_all"` // Matching a `.` will not exclude newlines. - MultiLine bool `json:"multi_line"` // Set multi-line anchoring. - SingleMatch bool `json:"single_match"` // Set single-match only mode. - Utf8Mode bool `json:"utf_8_mode"` // Enable UTF-8 mode for this expression. - UnicodeProperty bool `json:"unicode_property"` // Enable Unicode property support for this expression + Caseless bool `json:"caseless" bson:"caseless,omitempty"` // Set case-insensitive matching. + DotAll bool `json:"dot_all" bson:"dot_all,omitempty"` // Matching a `.` will not exclude newlines. + MultiLine bool `json:"multi_line" bson:"multi_line,omitempty"` // Set multi-line anchoring. + SingleMatch bool `json:"single_match" bson:"single_match,omitempty"` // Set single-match only mode. + Utf8Mode bool `json:"utf_8_mode" bson:"utf_8_mode,omitempty"` // Enable UTF-8 mode for this expression. + UnicodeProperty bool `json:"unicode_property" bson:"unicode_property,omitempty"` // Enable Unicode property support for this expression } type Pattern struct { - Regex string `json:"regex"` - Flags RegexFlags `json:"flags"` - MinOccurrences int `json:"min_occurrences"` - MaxOccurrences int `json:"max_occurrences"` + Regex string `json:"regex" binding:"min=1" bson:"regex"` + Flags RegexFlags `json:"flags" bson:"flags,omitempty"` + MinOccurrences uint `json:"min_occurrences" bson:"min_occurrences,omitempty"` + MaxOccurrences uint `json:"max_occurrences" binding:"omitempty,gtefield=MinOccurrences" bson:"max_occurrences,omitempty"` internalID int compiledPattern *hyperscan.Pattern } type Filter struct { - ServicePort int - ClientAddress string - ClientPort int - MinDuration int - MaxDuration int - MinPackets int - MaxPackets int - MinSize int - MaxSize int + ServicePort uint16 `json:"service_port" bson:"service_port,omitempty"` + ClientAddress string `json:"client_address" binding:"ip_addr" bson:"client_address,omitempty"` + ClientPort uint16 `json:"client_port" bson:"client_port,omitempty"` + MinDuration uint `json:"min_duration" bson:"min_duration,omitempty"` + MaxDuration uint `json:"max_duration" binding:"omitempty,gtefield=MinDuration" bson:"max_duration,omitempty"` + MinBytes uint `json:"min_bytes" bson:"min_bytes,omitempty"` + MaxBytes uint `json:"max_bytes" binding:"omitempty,gtefield=MinBytes" bson:"max_bytes,omitempty"` } type Rule struct { - ID RowID `json:"-" bson:"_id,omitempty"` - Name string `json:"name" binding:"required,min=3" bson:"name"` - Color string `json:"color" binding:"required,hexcolor" bson:"color"` + ID RowID `json:"id" bson:"_id,omitempty"` + Name string `json:"name" binding:"min=3" bson:"name"` + Color string `json:"color" binding:"hexcolor" bson:"color"` Notes string `json:"notes" bson:"notes,omitempty"` Enabled bool `json:"enabled" bson:"enabled"` Patterns []Pattern `json:"patterns" binding:"required,min=1" bson:"patterns"` @@ -53,38 +52,44 @@ type Rule struct { } type RulesDatabase struct { - database hyperscan.StreamDatabase + database hyperscan.StreamDatabase databaseSize int - version RowID + version RowID } type RulesManager interface { LoadRules() error - AddRule(context context.Context, rule Rule) (string, error) + AddRule(context context.Context, rule Rule) (RowID, error) + GetRule(id RowID) (Rule, bool) + UpdateRule(context context.Context, rule Rule) bool + GetRules() []Rule FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice, serverMatches map[uint][]PatternSlice) DatabaseUpdateChannel() chan RulesDatabase } type rulesManagerImpl struct { storage Storage - rules map[string]Rule + rules map[RowID]Rule rulesByName map[string]Rule - ruleIndex int patterns map[string]Pattern - mPatterns sync.Mutex + mutex sync.Mutex databaseUpdated chan RulesDatabase + validate *validator.Validate } func NewRulesManager(storage Storage) RulesManager { return &rulesManagerImpl{ - storage: storage, - rules: make(map[string]Rule), - patterns: make(map[string]Pattern), - mPatterns: sync.Mutex{}, + storage: storage, + rules: make(map[RowID]Rule), + rulesByName: make(map[string]Rule), + patterns: make(map[string]Pattern), + mutex: sync.Mutex{}, + databaseUpdated: make(chan RulesDatabase), + validate: validator.New(), } } -func (rm rulesManagerImpl) LoadRules() error { +func (rm *rulesManagerImpl) LoadRules() error { var rules []Rule if err := rm.storage.Find(Rules).Sort("_id", true).All(&rules); err != nil { return err @@ -96,48 +101,96 @@ func (rm rulesManagerImpl) LoadRules() error { } } - rm.ruleIndex = len(rules) return rm.generateDatabase(rules[len(rules)-1].ID) } -func (rm rulesManagerImpl) AddRule(context context.Context, rule Rule) (string, error) { - rm.mPatterns.Lock() +func (rm *rulesManagerImpl) AddRule(context context.Context, rule Rule) (RowID, error) { + rm.mutex.Lock() - rule.ID = rm.storage.NewCustomRowID(uint64(rm.ruleIndex), time.Now()) + rule.ID = rm.storage.NewCustomRowID(uint64(len(rm.rules)), time.Now()) rule.Enabled = true if err := rm.validateAndAddRuleLocal(&rule); err != nil { - rm.mPatterns.Unlock() - return "", err + rm.mutex.Unlock() + return rule.ID, err } if err := rm.generateDatabase(rule.ID); err != nil { - rm.mPatterns.Unlock() + rm.mutex.Unlock() log.WithError(err).WithField("rule", rule).Panic("failed to generate database") } - rm.mPatterns.Unlock() + rm.mutex.Unlock() if _, err := rm.storage.Insert(Rules).Context(context).One(rule); err != nil { log.WithError(err).WithField("rule", rule).Panic("failed to insert rule on database") } - return rule.ID.Hex(), nil + return rule.ID, nil } -func (rm rulesManagerImpl) validateAndAddRuleLocal(rule *Rule) error { +func (rm *rulesManagerImpl) GetRule(id RowID) (Rule, bool) { + rule, isPresent := rm.rules[id] + return rule, isPresent +} + +func (rm *rulesManagerImpl) UpdateRule(context context.Context, rule Rule) bool { + updated, err := rm.storage.Update(Rules).Context(context).Filter(OrderedDocument{{"_id", rule.ID}}). + One(UnorderedDocument{"name": rule.Name, "color": rule.Color}) + if err != nil { + log.WithError(err).WithField("rule", rule).Panic("failed to update rule on database") + } + + if updated { + rm.mutex.Lock() + newRule := rm.rules[rule.ID] + newRule.Name = rule.Name + newRule.Color = rule.Color + + delete(rm.rulesByName, newRule.Name) + rm.rulesByName[rule.Name] = newRule + rm.rules[rule.ID] = newRule + rm.mutex.Unlock() + } + + return updated +} + +func (rm *rulesManagerImpl) GetRules() []Rule { + rules := make([]Rule, 0, len(rm.rules)) + + for _, rule := range rm.rules { + rules = append(rules, rule) + } + + return rules +} + +func (rm *rulesManagerImpl) FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice, + serverMatches map[uint][]PatternSlice) { +} + +func (rm *rulesManagerImpl) DatabaseUpdateChannel() chan RulesDatabase { + return rm.databaseUpdated +} + +func (rm *rulesManagerImpl) validateAndAddRuleLocal(rule *Rule) error { if _, alreadyPresent := rm.rulesByName[rule.Name]; alreadyPresent { return errors.New("rule name must be unique") } newPatterns := make(map[string]Pattern) for i, pattern := range rule.Patterns { + if err := rm.validate.Struct(pattern); err != nil { + return err + } + hash := pattern.Hash() if existingPattern, isPresent := rm.patterns[hash]; isPresent { rule.Patterns[i] = existingPattern continue } - err := pattern.BuildPattern() - if err != nil { + + if err := pattern.BuildPattern(); err != nil { return err } pattern.internalID = len(rm.patterns) + len(newPatterns) @@ -148,18 +201,16 @@ func (rm rulesManagerImpl) validateAndAddRuleLocal(rule *Rule) error { rm.patterns[key] = value } - rm.rules[rule.ID.Hex()] = *rule + rm.rules[rule.ID] = *rule rm.rulesByName[rule.Name] = *rule return nil } -func (rm rulesManagerImpl) generateDatabase(version RowID) error { - patterns := make([]*hyperscan.Pattern, len(rm.patterns)) - var i int +func (rm *rulesManagerImpl) generateDatabase(version RowID) error { + patterns := make([]*hyperscan.Pattern, 0, len(rm.patterns)) for _, pattern := range rm.patterns { - patterns[i] = pattern.compiledPattern - i++ + patterns = append(patterns, pattern.compiledPattern) } database, err := hyperscan.NewStreamDatabase(patterns...) if err != nil { @@ -167,31 +218,17 @@ func (rm rulesManagerImpl) generateDatabase(version RowID) error { } rm.databaseUpdated <- RulesDatabase{ - database: database, + database: database, databaseSize: len(patterns), - version: version, + version: version, } return nil } -func (rm rulesManagerImpl) FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice, - serverMatches map[uint][]PatternSlice) { -} - -func (rm rulesManagerImpl) DatabaseUpdateChannel() chan RulesDatabase { - return rm.databaseUpdated -} - -func (p Pattern) BuildPattern() error { +func (p *Pattern) BuildPattern() error { if p.compiledPattern != nil { return nil } - if p.MinOccurrences <= 0 { - return errors.New("min_occurrences can't be lower than zero") - } - if p.MaxOccurrences != -1 && p.MinOccurrences < p.MinOccurrences { - return errors.New("max_occurrences can't be lower than min_occurrences") - } hp, err := hyperscan.ParsePattern(fmt.Sprintf("/%s/", p.Regex)) if err != nil { @@ -221,6 +258,8 @@ func (p Pattern) BuildPattern() error { return errors.New("can't validate the pattern") } + p.compiledPattern = hp + return nil } -- cgit v1.2.3-70-g09d2