diff options
author | Emiliano Ciavatta | 2020-04-21 10:57:10 +0000 |
---|---|---|
committer | Emiliano Ciavatta | 2020-04-21 10:57:10 +0000 |
commit | 516712f670803979c65fd3db73dd2de9d7175139 (patch) | |
tree | 44e701aee56031ea3a731352bd639ae1b060f83c /rules_manager.go | |
parent | 324623884309e95f541f285faee48150988ec466 (diff) |
Add application_context
Diffstat (limited to 'rules_manager.go')
-rw-r--r-- | rules_manager.go | 58 |
1 files changed, 36 insertions, 22 deletions
diff --git a/rules_manager.go b/rules_manager.go index fdd3eaa..015d8f7 100644 --- a/rules_manager.go +++ b/rules_manager.go @@ -25,12 +25,12 @@ type RegexFlags struct { } type Pattern struct { - Regex string `json:"regex" binding:"required,min=1" bson:"regex"` - Flags RegexFlags `json:"flags" bson:"flags,omitempty"` - MinOccurrences uint `json:"min_occurrences" bson:"min_occurrences,omitempty"` - MaxOccurrences uint `json:"max_occurrences" binding:"omitempty,gtefield=MinOccurrences" bson:"max_occurrences,omitempty"` - Direction uint8 `json:"direction" binding:"omitempty,max=2" bson:"direction,omitempty"` - internalID uint + Regex string `json:"regex" binding:"required,min=1" bson:"regex"` + Flags RegexFlags `json:"flags" bson:"flags,omitempty"` + MinOccurrences uint `json:"min_occurrences" bson:"min_occurrences,omitempty"` + MaxOccurrences uint `json:"max_occurrences" binding:"omitempty,gtefield=MinOccurrences" bson:"max_occurrences,omitempty"` + Direction uint8 `json:"direction" binding:"omitempty,max=2" bson:"direction,omitempty"` + internalID uint } type Filter struct { @@ -61,12 +61,10 @@ type RulesDatabase struct { } type RulesManager interface { - LoadRules() error AddRule(context context.Context, rule Rule) (RowID, error) GetRule(id RowID) (Rule, bool) UpdateRule(context context.Context, id RowID, rule Rule) (bool, error) GetRules() []Rule - SetFlag(context context.Context, flagRegex string) error FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice, serverMatches map[uint][]PatternSlice) DatabaseUpdateChannel() chan RulesDatabase } @@ -82,8 +80,13 @@ type rulesManagerImpl struct { validate *validator.Validate } -func NewRulesManager(storage Storage) RulesManager { - return &rulesManagerImpl{ +func LoadRulesManager(storage Storage, flagRegex string) (RulesManager, error) { + var rules []Rule + if err := storage.Find(Rules).Sort("_id", true).All(&rules); err != nil { + return nil, err + } + + rulesManager := rulesManagerImpl{ storage: storage, rules: make(map[RowID]Rule), rulesByName: make(map[string]Rule), @@ -93,21 +96,32 @@ func NewRulesManager(storage Storage) RulesManager { databaseUpdated: make(chan RulesDatabase, 1), validate: validator.New(), } -} -func (rm *rulesManagerImpl) LoadRules() error { - var rules []Rule - if err := rm.storage.Find(Rules).Sort("_id", true).All(&rules); err != nil { - return err + for _, rule := range rules { + if err := rulesManager.validateAndAddRuleLocal(&rule); err != nil { + return nil, err + } } - for _, rule := range rules { - if err := rm.validateAndAddRuleLocal(&rule); err != nil { - log.WithError(err).WithField("rule", rule).Warn("failed to import rule") + // if there are no rules in database (e.g. first run), set flagRegex as first rule + if len(rulesManager.rules) == 0 { + if _, err := rulesManager.AddRule(context.Background(), Rule{ + Name: "flag", + Color: "#ff0000", + Notes: "Mark connections where the flag is stolen", + Patterns: []Pattern{ + {Regex: flagRegex, Direction: DirectionToClient}, + }, + }); err != nil { + return nil, err + } + } else { + if err := rulesManager.generateDatabase(rules[len(rules)-1].ID); err != nil { + return nil, err } } - return rm.generateDatabase(rules[len(rules)-1].ID) + return &rulesManager, nil } func (rm *rulesManagerImpl) AddRule(context context.Context, rule Rule) (RowID, error) { @@ -197,7 +211,7 @@ func (rm *rulesManagerImpl) FillWithMatchedRules(connection *Connection, clientM serverMatches map[uint][]PatternSlice) { rm.mutex.Lock() - filterFunctions := []func (rule Rule)bool { + filterFunctions := []func(rule Rule) bool{ func(rule Rule) bool { return rule.Filter.ClientAddress == "" || connection.SourceIP == rule.Filter.ClientAddress }, @@ -216,11 +230,11 @@ func (rm *rulesManagerImpl) FillWithMatchedRules(connection *Connection, clientM rule.Filter.MaxDuration }, func(rule Rule) bool { - return rule.Filter.MinBytes == 0 || uint(connection.ClientBytes + connection.ServerBytes) >= + return rule.Filter.MinBytes == 0 || uint(connection.ClientBytes+connection.ServerBytes) >= rule.Filter.MinBytes }, func(rule Rule) bool { - return rule.Filter.MaxBytes == 0 || uint(connection.ClientBytes + connection.ServerBytes) <= + return rule.Filter.MaxBytes == 0 || uint(connection.ClientBytes+connection.ServerBytes) <= rule.Filter.MinBytes }, } |