From a56a4e391d541ae05de0203f3d493edc3b04681d Mon Sep 17 00:00:00 2001 From: Emiliano Ciavatta Date: Mon, 13 Apr 2020 17:12:35 +0200 Subject: Add AddRoute tests --- rules_manager.go | 103 ++++++++++++++++++++++++++++++------------------------- 1 file changed, 57 insertions(+), 46 deletions(-) (limited to 'rules_manager.go') diff --git a/rules_manager.go b/rules_manager.go index 388aeee..57f8768 100644 --- a/rules_manager.go +++ b/rules_manager.go @@ -2,7 +2,6 @@ package main import ( "context" - "crypto/sha256" "errors" "fmt" "github.com/flier/gohs/hyperscan" @@ -12,6 +11,10 @@ import ( "time" ) +const DirectionBoth = 0 +const DirectionToServer = 1 +const DirectionToClient = 2 + type RegexFlags struct { Caseless bool `json:"caseless" bson:"caseless,omitempty"` // Set case-insensitive matching. DotAll bool `json:"dot_all" bson:"dot_all,omitempty"` // Matching a `.` will not exclude newlines. @@ -26,13 +29,13 @@ type Pattern struct { Flags RegexFlags `json:"flags" bson:"flags,omitempty"` MinOccurrences uint `json:"min_occurrences" bson:"min_occurrences,omitempty"` MaxOccurrences uint `json:"max_occurrences" binding:"omitempty,gtefield=MinOccurrences" bson:"max_occurrences,omitempty"` + Direction uint8 `json:"direction" binding:"omitempty,max=2" bson:"direction,omitempty"` internalID int - compiledPattern *hyperscan.Pattern } type Filter struct { ServicePort uint16 `json:"service_port" bson:"service_port,omitempty"` - ClientAddress string `json:"client_address" binding:"ip_addr" bson:"client_address,omitempty"` + ClientAddress string `json:"client_address" binding:"omitempty,ip_addr" bson:"client_address,omitempty"` ClientPort uint16 `json:"client_port" bson:"client_port,omitempty"` MinDuration uint `json:"min_duration" bson:"min_duration,omitempty"` MaxDuration uint `json:"max_duration" binding:"omitempty,gtefield=MinDuration" bson:"max_duration,omitempty"` @@ -46,7 +49,7 @@ type Rule struct { Color string `json:"color" binding:"hexcolor" bson:"color"` Notes string `json:"notes" bson:"notes,omitempty"` Enabled bool `json:"enabled" bson:"enabled"` - Patterns []Pattern `json:"patterns" binding:"required,min=1" bson:"patterns"` + Patterns []Pattern `json:"patterns" bson:"patterns"` Filter Filter `json:"filter" bson:"filter,omitempty"` Version int64 `json:"version" bson:"version"` } @@ -61,7 +64,7 @@ type RulesManager interface { LoadRules() error AddRule(context context.Context, rule Rule) (RowID, error) GetRule(id RowID) (Rule, bool) - UpdateRule(context context.Context, rule Rule) bool + UpdateRule(context context.Context, id RowID, rule Rule) bool GetRules() []Rule FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice, serverMatches map[uint][]PatternSlice) DatabaseUpdateChannel() chan RulesDatabase @@ -71,7 +74,8 @@ type rulesManagerImpl struct { storage Storage rules map[RowID]Rule rulesByName map[string]Rule - patterns map[string]Pattern + patterns []*hyperscan.Pattern + patternsIds map[string]int mutex sync.Mutex databaseUpdated chan RulesDatabase validate *validator.Validate @@ -82,9 +86,10 @@ func NewRulesManager(storage Storage) RulesManager { storage: storage, rules: make(map[RowID]Rule), rulesByName: make(map[string]Rule), - patterns: make(map[string]Pattern), + patterns: make([]*hyperscan.Pattern, 0), + patternsIds: make(map[string]int), mutex: sync.Mutex{}, - databaseUpdated: make(chan RulesDatabase), + databaseUpdated: make(chan RulesDatabase, 1), validate: validator.New(), } } @@ -107,12 +112,12 @@ func (rm *rulesManagerImpl) LoadRules() error { func (rm *rulesManagerImpl) AddRule(context context.Context, rule Rule) (RowID, error) { rm.mutex.Lock() - rule.ID = rm.storage.NewCustomRowID(uint64(len(rm.rules)), time.Now()) + rule.ID = CustomRowID(uint64(len(rm.rules)), time.Now()) rule.Enabled = true if err := rm.validateAndAddRuleLocal(&rule); err != nil { rm.mutex.Unlock() - return rule.ID, err + return EmptyRowID(), err } if err := rm.generateDatabase(rule.ID); err != nil { @@ -133,8 +138,13 @@ func (rm *rulesManagerImpl) GetRule(id RowID) (Rule, bool) { return rule, isPresent } -func (rm *rulesManagerImpl) UpdateRule(context context.Context, rule Rule) bool { - updated, err := rm.storage.Update(Rules).Context(context).Filter(OrderedDocument{{"_id", rule.ID}}). +func (rm *rulesManagerImpl) UpdateRule(context context.Context, id RowID, rule Rule) bool { + newRule, isPresent := rm.rules[id] + if !isPresent { + return false + } + + updated, err := rm.storage.Update(Rules).Context(context).Filter(OrderedDocument{{"_id", id}}). One(UnorderedDocument{"name": rule.Name, "color": rule.Color}) if err != nil { log.WithError(err).WithField("rule", rule).Panic("failed to update rule on database") @@ -142,7 +152,6 @@ func (rm *rulesManagerImpl) UpdateRule(context context.Context, rule Rule) bool if updated { rm.mutex.Lock() - newRule := rm.rules[rule.ID] newRule.Name = rule.Name newRule.Color = rule.Color @@ -165,6 +174,19 @@ func (rm *rulesManagerImpl) GetRules() []Rule { return rules } +func (rm *rulesManagerImpl) SetFlag(context context.Context, flagRegex string) error { + _, err := rm.AddRule(context, Rule{ + Name: "flag", + Color: "#ff0000", + Notes: "Mark connections where the flag is stolen", + Patterns: []Pattern{ + {Regex: flagRegex, Direction: DirectionToClient}, + }, + }) + + return err +} + func (rm *rulesManagerImpl) FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice, serverMatches map[uint][]PatternSlice) { } @@ -178,27 +200,31 @@ func (rm *rulesManagerImpl) validateAndAddRuleLocal(rule *Rule) error { return errors.New("rule name must be unique") } - newPatterns := make(map[string]Pattern) + newPatterns := make([]*hyperscan.Pattern, 0, len(rule.Patterns)) for i, pattern := range rule.Patterns { if err := rm.validate.Struct(pattern); err != nil { return err } - hash := pattern.Hash() - if existingPattern, isPresent := rm.patterns[hash]; isPresent { - rule.Patterns[i] = existingPattern + compiledPattern, err := pattern.BuildPattern() + if err != nil { + return err + } + if existingPattern, isPresent := rm.patternsIds[compiledPattern.String()]; isPresent { + rule.Patterns[i].internalID = existingPattern continue } - if err := pattern.BuildPattern(); err != nil { - return err - } - pattern.internalID = len(rm.patterns) + len(newPatterns) - newPatterns[hash] = pattern + id := len(rm.patternsIds) + len(newPatterns) + rule.Patterns[i].internalID = id + compiledPattern.Id = id + newPatterns = append(newPatterns, compiledPattern) } - for key, value := range newPatterns { - rm.patterns[key] = value + startId := len(rm.patterns) + for id, pattern := range newPatterns { + rm.patterns = append(rm.patterns, pattern) + rm.patternsIds[pattern.String()] = startId + id } rm.rules[rule.ID] = *rule @@ -208,31 +234,24 @@ func (rm *rulesManagerImpl) validateAndAddRuleLocal(rule *Rule) error { } func (rm *rulesManagerImpl) generateDatabase(version RowID) error { - patterns := make([]*hyperscan.Pattern, 0, len(rm.patterns)) - for _, pattern := range rm.patterns { - patterns = append(patterns, pattern.compiledPattern) - } - database, err := hyperscan.NewStreamDatabase(patterns...) + database, err := hyperscan.NewStreamDatabase(rm.patterns...) if err != nil { return err } rm.databaseUpdated <- RulesDatabase{ database: database, - databaseSize: len(patterns), + databaseSize: len(rm.patterns), version: version, } + return nil } -func (p *Pattern) BuildPattern() error { - if p.compiledPattern != nil { - return nil - } - +func (p *Pattern) BuildPattern() (*hyperscan.Pattern, error) { hp, err := hyperscan.ParsePattern(fmt.Sprintf("/%s/", p.Regex)) if err != nil { - return err + return nil, err } if p.Flags.Caseless { @@ -255,16 +274,8 @@ func (p *Pattern) BuildPattern() error { } if !hp.IsValid() { - return errors.New("can't validate the pattern") + return nil, errors.New("can't validate the pattern") } - p.compiledPattern = hp - - return nil -} - -func (p Pattern) Hash() string { - hash := sha256.New() - hash.Write([]byte(fmt.Sprintf("%s|%v|%v|%v", p.Regex, p.Flags, p.MinOccurrences, p.MaxOccurrences))) - return fmt.Sprintf("%x", hash.Sum(nil)) + return hp, nil } -- cgit v1.2.3-70-g09d2