aboutsummaryrefslogtreecommitdiff
path: root/rules_manager.go
diff options
context:
space:
mode:
authorEmiliano Ciavatta2020-04-13 15:12:35 +0000
committerEmiliano Ciavatta2020-04-13 15:12:35 +0000
commita56a4e391d541ae05de0203f3d493edc3b04681d (patch)
treeab9344a650305aafb5afe552dc8cad63684de643 /rules_manager.go
parent7113463dead05631339fdab94de9440201c42489 (diff)
Add AddRoute tests
Diffstat (limited to 'rules_manager.go')
-rw-r--r--rules_manager.go103
1 files changed, 57 insertions, 46 deletions
diff --git a/rules_manager.go b/rules_manager.go
index 388aeee..57f8768 100644
--- a/rules_manager.go
+++ b/rules_manager.go
@@ -2,7 +2,6 @@ package main
import (
"context"
- "crypto/sha256"
"errors"
"fmt"
"github.com/flier/gohs/hyperscan"
@@ -12,6 +11,10 @@ import (
"time"
)
+const DirectionBoth = 0
+const DirectionToServer = 1
+const DirectionToClient = 2
+
type RegexFlags struct {
Caseless bool `json:"caseless" bson:"caseless,omitempty"` // Set case-insensitive matching.
DotAll bool `json:"dot_all" bson:"dot_all,omitempty"` // Matching a `.` will not exclude newlines.
@@ -26,13 +29,13 @@ type Pattern struct {
Flags RegexFlags `json:"flags" bson:"flags,omitempty"`
MinOccurrences uint `json:"min_occurrences" bson:"min_occurrences,omitempty"`
MaxOccurrences uint `json:"max_occurrences" binding:"omitempty,gtefield=MinOccurrences" bson:"max_occurrences,omitempty"`
+ Direction uint8 `json:"direction" binding:"omitempty,max=2" bson:"direction,omitempty"`
internalID int
- compiledPattern *hyperscan.Pattern
}
type Filter struct {
ServicePort uint16 `json:"service_port" bson:"service_port,omitempty"`
- ClientAddress string `json:"client_address" binding:"ip_addr" bson:"client_address,omitempty"`
+ ClientAddress string `json:"client_address" binding:"omitempty,ip_addr" bson:"client_address,omitempty"`
ClientPort uint16 `json:"client_port" bson:"client_port,omitempty"`
MinDuration uint `json:"min_duration" bson:"min_duration,omitempty"`
MaxDuration uint `json:"max_duration" binding:"omitempty,gtefield=MinDuration" bson:"max_duration,omitempty"`
@@ -46,7 +49,7 @@ type Rule struct {
Color string `json:"color" binding:"hexcolor" bson:"color"`
Notes string `json:"notes" bson:"notes,omitempty"`
Enabled bool `json:"enabled" bson:"enabled"`
- Patterns []Pattern `json:"patterns" binding:"required,min=1" bson:"patterns"`
+ Patterns []Pattern `json:"patterns" bson:"patterns"`
Filter Filter `json:"filter" bson:"filter,omitempty"`
Version int64 `json:"version" bson:"version"`
}
@@ -61,7 +64,7 @@ type RulesManager interface {
LoadRules() error
AddRule(context context.Context, rule Rule) (RowID, error)
GetRule(id RowID) (Rule, bool)
- UpdateRule(context context.Context, rule Rule) bool
+ UpdateRule(context context.Context, id RowID, rule Rule) bool
GetRules() []Rule
FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice, serverMatches map[uint][]PatternSlice)
DatabaseUpdateChannel() chan RulesDatabase
@@ -71,7 +74,8 @@ type rulesManagerImpl struct {
storage Storage
rules map[RowID]Rule
rulesByName map[string]Rule
- patterns map[string]Pattern
+ patterns []*hyperscan.Pattern
+ patternsIds map[string]int
mutex sync.Mutex
databaseUpdated chan RulesDatabase
validate *validator.Validate
@@ -82,9 +86,10 @@ func NewRulesManager(storage Storage) RulesManager {
storage: storage,
rules: make(map[RowID]Rule),
rulesByName: make(map[string]Rule),
- patterns: make(map[string]Pattern),
+ patterns: make([]*hyperscan.Pattern, 0),
+ patternsIds: make(map[string]int),
mutex: sync.Mutex{},
- databaseUpdated: make(chan RulesDatabase),
+ databaseUpdated: make(chan RulesDatabase, 1),
validate: validator.New(),
}
}
@@ -107,12 +112,12 @@ func (rm *rulesManagerImpl) LoadRules() error {
func (rm *rulesManagerImpl) AddRule(context context.Context, rule Rule) (RowID, error) {
rm.mutex.Lock()
- rule.ID = rm.storage.NewCustomRowID(uint64(len(rm.rules)), time.Now())
+ rule.ID = CustomRowID(uint64(len(rm.rules)), time.Now())
rule.Enabled = true
if err := rm.validateAndAddRuleLocal(&rule); err != nil {
rm.mutex.Unlock()
- return rule.ID, err
+ return EmptyRowID(), err
}
if err := rm.generateDatabase(rule.ID); err != nil {
@@ -133,8 +138,13 @@ func (rm *rulesManagerImpl) GetRule(id RowID) (Rule, bool) {
return rule, isPresent
}
-func (rm *rulesManagerImpl) UpdateRule(context context.Context, rule Rule) bool {
- updated, err := rm.storage.Update(Rules).Context(context).Filter(OrderedDocument{{"_id", rule.ID}}).
+func (rm *rulesManagerImpl) UpdateRule(context context.Context, id RowID, rule Rule) bool {
+ newRule, isPresent := rm.rules[id]
+ if !isPresent {
+ return false
+ }
+
+ updated, err := rm.storage.Update(Rules).Context(context).Filter(OrderedDocument{{"_id", id}}).
One(UnorderedDocument{"name": rule.Name, "color": rule.Color})
if err != nil {
log.WithError(err).WithField("rule", rule).Panic("failed to update rule on database")
@@ -142,7 +152,6 @@ func (rm *rulesManagerImpl) UpdateRule(context context.Context, rule Rule) bool
if updated {
rm.mutex.Lock()
- newRule := rm.rules[rule.ID]
newRule.Name = rule.Name
newRule.Color = rule.Color
@@ -165,6 +174,19 @@ func (rm *rulesManagerImpl) GetRules() []Rule {
return rules
}
+func (rm *rulesManagerImpl) SetFlag(context context.Context, flagRegex string) error {
+ _, err := rm.AddRule(context, Rule{
+ Name: "flag",
+ Color: "#ff0000",
+ Notes: "Mark connections where the flag is stolen",
+ Patterns: []Pattern{
+ {Regex: flagRegex, Direction: DirectionToClient},
+ },
+ })
+
+ return err
+}
+
func (rm *rulesManagerImpl) FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice,
serverMatches map[uint][]PatternSlice) {
}
@@ -178,27 +200,31 @@ func (rm *rulesManagerImpl) validateAndAddRuleLocal(rule *Rule) error {
return errors.New("rule name must be unique")
}
- newPatterns := make(map[string]Pattern)
+ newPatterns := make([]*hyperscan.Pattern, 0, len(rule.Patterns))
for i, pattern := range rule.Patterns {
if err := rm.validate.Struct(pattern); err != nil {
return err
}
- hash := pattern.Hash()
- if existingPattern, isPresent := rm.patterns[hash]; isPresent {
- rule.Patterns[i] = existingPattern
+ compiledPattern, err := pattern.BuildPattern()
+ if err != nil {
+ return err
+ }
+ if existingPattern, isPresent := rm.patternsIds[compiledPattern.String()]; isPresent {
+ rule.Patterns[i].internalID = existingPattern
continue
}
- if err := pattern.BuildPattern(); err != nil {
- return err
- }
- pattern.internalID = len(rm.patterns) + len(newPatterns)
- newPatterns[hash] = pattern
+ id := len(rm.patternsIds) + len(newPatterns)
+ rule.Patterns[i].internalID = id
+ compiledPattern.Id = id
+ newPatterns = append(newPatterns, compiledPattern)
}
- for key, value := range newPatterns {
- rm.patterns[key] = value
+ startId := len(rm.patterns)
+ for id, pattern := range newPatterns {
+ rm.patterns = append(rm.patterns, pattern)
+ rm.patternsIds[pattern.String()] = startId + id
}
rm.rules[rule.ID] = *rule
@@ -208,31 +234,24 @@ func (rm *rulesManagerImpl) validateAndAddRuleLocal(rule *Rule) error {
}
func (rm *rulesManagerImpl) generateDatabase(version RowID) error {
- patterns := make([]*hyperscan.Pattern, 0, len(rm.patterns))
- for _, pattern := range rm.patterns {
- patterns = append(patterns, pattern.compiledPattern)
- }
- database, err := hyperscan.NewStreamDatabase(patterns...)
+ database, err := hyperscan.NewStreamDatabase(rm.patterns...)
if err != nil {
return err
}
rm.databaseUpdated <- RulesDatabase{
database: database,
- databaseSize: len(patterns),
+ databaseSize: len(rm.patterns),
version: version,
}
+
return nil
}
-func (p *Pattern) BuildPattern() error {
- if p.compiledPattern != nil {
- return nil
- }
-
+func (p *Pattern) BuildPattern() (*hyperscan.Pattern, error) {
hp, err := hyperscan.ParsePattern(fmt.Sprintf("/%s/", p.Regex))
if err != nil {
- return err
+ return nil, err
}
if p.Flags.Caseless {
@@ -255,16 +274,8 @@ func (p *Pattern) BuildPattern() error {
}
if !hp.IsValid() {
- return errors.New("can't validate the pattern")
+ return nil, errors.New("can't validate the pattern")
}
- p.compiledPattern = hp
-
- return nil
-}
-
-func (p Pattern) Hash() string {
- hash := sha256.New()
- hash.Write([]byte(fmt.Sprintf("%s|%v|%v|%v", p.Regex, p.Flags, p.MinOccurrences, p.MaxOccurrences)))
- return fmt.Sprintf("%x", hash.Sum(nil))
+ return hp, nil
}