aboutsummaryrefslogtreecommitdiff
path: root/rules_manager.go
diff options
context:
space:
mode:
authorEmiliano Ciavatta2020-04-21 10:57:10 +0000
committerEmiliano Ciavatta2020-04-21 10:57:10 +0000
commit516712f670803979c65fd3db73dd2de9d7175139 (patch)
tree44e701aee56031ea3a731352bd639ae1b060f83c /rules_manager.go
parent324623884309e95f541f285faee48150988ec466 (diff)
Add application_context
Diffstat (limited to 'rules_manager.go')
-rw-r--r--rules_manager.go58
1 files changed, 36 insertions, 22 deletions
diff --git a/rules_manager.go b/rules_manager.go
index fdd3eaa..015d8f7 100644
--- a/rules_manager.go
+++ b/rules_manager.go
@@ -25,12 +25,12 @@ type RegexFlags struct {
}
type Pattern struct {
- Regex string `json:"regex" binding:"required,min=1" bson:"regex"`
- Flags RegexFlags `json:"flags" bson:"flags,omitempty"`
- MinOccurrences uint `json:"min_occurrences" bson:"min_occurrences,omitempty"`
- MaxOccurrences uint `json:"max_occurrences" binding:"omitempty,gtefield=MinOccurrences" bson:"max_occurrences,omitempty"`
- Direction uint8 `json:"direction" binding:"omitempty,max=2" bson:"direction,omitempty"`
- internalID uint
+ Regex string `json:"regex" binding:"required,min=1" bson:"regex"`
+ Flags RegexFlags `json:"flags" bson:"flags,omitempty"`
+ MinOccurrences uint `json:"min_occurrences" bson:"min_occurrences,omitempty"`
+ MaxOccurrences uint `json:"max_occurrences" binding:"omitempty,gtefield=MinOccurrences" bson:"max_occurrences,omitempty"`
+ Direction uint8 `json:"direction" binding:"omitempty,max=2" bson:"direction,omitempty"`
+ internalID uint
}
type Filter struct {
@@ -61,12 +61,10 @@ type RulesDatabase struct {
}
type RulesManager interface {
- LoadRules() error
AddRule(context context.Context, rule Rule) (RowID, error)
GetRule(id RowID) (Rule, bool)
UpdateRule(context context.Context, id RowID, rule Rule) (bool, error)
GetRules() []Rule
- SetFlag(context context.Context, flagRegex string) error
FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice, serverMatches map[uint][]PatternSlice)
DatabaseUpdateChannel() chan RulesDatabase
}
@@ -82,8 +80,13 @@ type rulesManagerImpl struct {
validate *validator.Validate
}
-func NewRulesManager(storage Storage) RulesManager {
- return &rulesManagerImpl{
+func LoadRulesManager(storage Storage, flagRegex string) (RulesManager, error) {
+ var rules []Rule
+ if err := storage.Find(Rules).Sort("_id", true).All(&rules); err != nil {
+ return nil, err
+ }
+
+ rulesManager := rulesManagerImpl{
storage: storage,
rules: make(map[RowID]Rule),
rulesByName: make(map[string]Rule),
@@ -93,21 +96,32 @@ func NewRulesManager(storage Storage) RulesManager {
databaseUpdated: make(chan RulesDatabase, 1),
validate: validator.New(),
}
-}
-func (rm *rulesManagerImpl) LoadRules() error {
- var rules []Rule
- if err := rm.storage.Find(Rules).Sort("_id", true).All(&rules); err != nil {
- return err
+ for _, rule := range rules {
+ if err := rulesManager.validateAndAddRuleLocal(&rule); err != nil {
+ return nil, err
+ }
}
- for _, rule := range rules {
- if err := rm.validateAndAddRuleLocal(&rule); err != nil {
- log.WithError(err).WithField("rule", rule).Warn("failed to import rule")
+ // if there are no rules in database (e.g. first run), set flagRegex as first rule
+ if len(rulesManager.rules) == 0 {
+ if _, err := rulesManager.AddRule(context.Background(), Rule{
+ Name: "flag",
+ Color: "#ff0000",
+ Notes: "Mark connections where the flag is stolen",
+ Patterns: []Pattern{
+ {Regex: flagRegex, Direction: DirectionToClient},
+ },
+ }); err != nil {
+ return nil, err
+ }
+ } else {
+ if err := rulesManager.generateDatabase(rules[len(rules)-1].ID); err != nil {
+ return nil, err
}
}
- return rm.generateDatabase(rules[len(rules)-1].ID)
+ return &rulesManager, nil
}
func (rm *rulesManagerImpl) AddRule(context context.Context, rule Rule) (RowID, error) {
@@ -197,7 +211,7 @@ func (rm *rulesManagerImpl) FillWithMatchedRules(connection *Connection, clientM
serverMatches map[uint][]PatternSlice) {
rm.mutex.Lock()
- filterFunctions := []func (rule Rule)bool {
+ filterFunctions := []func(rule Rule) bool{
func(rule Rule) bool {
return rule.Filter.ClientAddress == "" || connection.SourceIP == rule.Filter.ClientAddress
},
@@ -216,11 +230,11 @@ func (rm *rulesManagerImpl) FillWithMatchedRules(connection *Connection, clientM
rule.Filter.MaxDuration
},
func(rule Rule) bool {
- return rule.Filter.MinBytes == 0 || uint(connection.ClientBytes + connection.ServerBytes) >=
+ return rule.Filter.MinBytes == 0 || uint(connection.ClientBytes+connection.ServerBytes) >=
rule.Filter.MinBytes
},
func(rule Rule) bool {
- return rule.Filter.MaxBytes == 0 || uint(connection.ClientBytes + connection.ServerBytes) <=
+ return rule.Filter.MaxBytes == 0 || uint(connection.ClientBytes+connection.ServerBytes) <=
rule.Filter.MinBytes
},
}