diff options
-rw-r--r-- | connection_handler_test.go | 4 | ||||
-rw-r--r-- | routes.go | 2 | ||||
-rw-r--r-- | rules_manager.go | 87 | ||||
-rw-r--r-- | rules_manager_test.go | 117 |
4 files changed, 185 insertions, 25 deletions
diff --git a/connection_handler_test.go b/connection_handler_test.go index 9fe41af..7b00f6e 100644 --- a/connection_handler_test.go +++ b/connection_handler_test.go @@ -194,6 +194,10 @@ func (rm TestRulesManager) GetRules() []Rule { return nil } +func (rm TestRulesManager) SetFlag(_ context.Context, _ string) error { + return nil +} + func (rm TestRulesManager) FillWithMatchedRules(_ *Connection, _ map[uint][]PatternSlice, _ map[uint][]PatternSlice) { } @@ -95,7 +95,7 @@ func badRequest(c *gin.Context, err error) { } func unprocessableEntity(c *gin.Context, err error) { - c.JSON(http.StatusOK, UnorderedDocument{"result": "error", "error": err.Error()}) + c.JSON(http.StatusUnprocessableEntity, UnorderedDocument{"result": "error", "error": err.Error()}) } func notFound(c *gin.Context, obj interface{}) { diff --git a/rules_manager.go b/rules_manager.go index 89b6153..fdd3eaa 100644 --- a/rules_manager.go +++ b/rules_manager.go @@ -25,17 +25,17 @@ type RegexFlags struct { } type Pattern struct { - Regex string `json:"regex" binding:"min=1" bson:"regex"` + Regex string `json:"regex" binding:"required,min=1" bson:"regex"` Flags RegexFlags `json:"flags" bson:"flags,omitempty"` MinOccurrences uint `json:"min_occurrences" bson:"min_occurrences,omitempty"` MaxOccurrences uint `json:"max_occurrences" binding:"omitempty,gtefield=MinOccurrences" bson:"max_occurrences,omitempty"` Direction uint8 `json:"direction" binding:"omitempty,max=2" bson:"direction,omitempty"` - internalID int + internalID uint } type Filter struct { ServicePort uint16 `json:"service_port" bson:"service_port,omitempty"` - ClientAddress string `json:"client_address" binding:"omitempty,ip_addr" bson:"client_address,omitempty"` + ClientAddress string `json:"client_address" binding:"omitempty,ip" bson:"client_address,omitempty"` ClientPort uint16 `json:"client_port" bson:"client_port,omitempty"` MinDuration uint `json:"min_duration" bson:"min_duration,omitempty"` MaxDuration uint `json:"max_duration" binding:"omitempty,gtefield=MinDuration" bson:"max_duration,omitempty"` @@ -66,6 +66,7 @@ type RulesManager interface { GetRule(id RowID) (Rule, bool) UpdateRule(context context.Context, id RowID, rule Rule) (bool, error) GetRules() []Rule + SetFlag(context context.Context, flagRegex string) error FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice, serverMatches map[uint][]PatternSlice) DatabaseUpdateChannel() chan RulesDatabase } @@ -75,7 +76,7 @@ type rulesManagerImpl struct { rules map[RowID]Rule rulesByName map[string]Rule patterns []*hyperscan.Pattern - patternsIds map[string]int + patternsIds map[string]uint mutex sync.Mutex databaseUpdated chan RulesDatabase validate *validator.Validate @@ -87,7 +88,7 @@ func NewRulesManager(storage Storage) RulesManager { rules: make(map[RowID]Rule), rulesByName: make(map[string]Rule), patterns: make([]*hyperscan.Pattern, 0), - patternsIds: make(map[string]int), + patternsIds: make(map[string]uint), mutex: sync.Mutex{}, databaseUpdated: make(chan RulesDatabase, 1), validate: validator.New(), @@ -194,6 +195,78 @@ func (rm *rulesManagerImpl) SetFlag(context context.Context, flagRegex string) e func (rm *rulesManagerImpl) FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice, serverMatches map[uint][]PatternSlice) { + rm.mutex.Lock() + + filterFunctions := []func (rule Rule)bool { + func(rule Rule) bool { + return rule.Filter.ClientAddress == "" || connection.SourceIP == rule.Filter.ClientAddress + }, + func(rule Rule) bool { + return rule.Filter.ClientPort == 0 || connection.SourcePort == rule.Filter.ClientPort + }, + func(rule Rule) bool { + return rule.Filter.ServicePort == 0 || connection.DestinationPort == rule.Filter.ServicePort + }, + func(rule Rule) bool { + return rule.Filter.MinDuration == 0 || uint(connection.ClosedAt.Sub(connection.StartedAt).Milliseconds()) >= + rule.Filter.MinDuration + }, + func(rule Rule) bool { + return rule.Filter.MaxDuration == 0 || uint(connection.ClosedAt.Sub(connection.StartedAt).Milliseconds()) <= + rule.Filter.MaxDuration + }, + func(rule Rule) bool { + return rule.Filter.MinBytes == 0 || uint(connection.ClientBytes + connection.ServerBytes) >= + rule.Filter.MinBytes + }, + func(rule Rule) bool { + return rule.Filter.MaxBytes == 0 || uint(connection.ClientBytes + connection.ServerBytes) <= + rule.Filter.MinBytes + }, + } + + connection.MatchedRules = make([]RowID, 0) + for _, rule := range rm.rules { + matching := true + for _, f := range filterFunctions { + if !f(rule) { + matching = false + break + } + } + + for _, p := range rule.Patterns { + checkOccurrences := func(occurrences []PatternSlice) bool { + return (p.MinOccurrences == 0 || uint(len(occurrences)) >= p.MinOccurrences) && + (p.MaxOccurrences == 0 || uint(len(occurrences)) <= p.MaxOccurrences) + } + clientOccurrences, clientPresent := clientMatches[p.internalID] + serverOccurrences, serverPresent := serverMatches[p.internalID] + + if p.Direction == DirectionToServer { + if !clientPresent || !checkOccurrences(clientOccurrences) { + matching = false + break + } + } else if p.Direction == DirectionToClient { + if !serverPresent || !checkOccurrences(serverOccurrences) { + matching = false + break + } + } else { + if !(clientPresent || serverPresent) || !checkOccurrences(append(clientOccurrences, serverOccurrences...)) { + matching = false + break + } + } + } + + if matching { + connection.MatchedRules = append(connection.MatchedRules, rule.ID) + } + } + + rm.mutex.Unlock() } func (rm *rulesManagerImpl) DatabaseUpdateChannel() chan RulesDatabase { @@ -226,7 +299,7 @@ func (rm *rulesManagerImpl) validateAndAddRuleLocal(rule *Rule) error { } id := len(rm.patternsIds) + len(newPatterns) - rule.Patterns[i].internalID = id + rule.Patterns[i].internalID = uint(id) compiledPattern.Id = id newPatterns = append(newPatterns, compiledPattern) duplicatePatterns[regex] = true @@ -235,7 +308,7 @@ func (rm *rulesManagerImpl) validateAndAddRuleLocal(rule *Rule) error { startId := len(rm.patterns) for id, pattern := range newPatterns { rm.patterns = append(rm.patterns, pattern) - rm.patternsIds[pattern.String()] = startId + id + rm.patternsIds[pattern.String()] = uint(startId + id) } rm.rules[rule.ID] = *rule diff --git a/rules_manager_test.go b/rules_manager_test.go index 68e3ae4..59ce11c 100644 --- a/rules_manager_test.go +++ b/rules_manager_test.go @@ -13,25 +13,14 @@ func TestAddAndGetAllRules(t *testing.T) { rulesManager := NewRulesManager(wrapper.Storage).(*rulesManagerImpl) - checkVersion := func(id RowID) { - timeout := time.Tick(1 * time.Second) - - select { - case database := <-rulesManager.DatabaseUpdateChannel(): - assert.Equal(t, id, database.version) - case <-timeout: - t.Fatal("timeout") - } - } - err := rulesManager.SetFlag(wrapper.Context, "FLAG{test}") assert.NoError(t, err) - checkVersion(rulesManager.rulesByName["flag"].ID) + checkVersion(t, rulesManager, rulesManager.rulesByName["flag"].ID) emptyRule := Rule{Name: "empty", Color: "#fff", Enabled: true} emptyID, err := rulesManager.AddRule(wrapper.Context, emptyRule) assert.NoError(t, err) assert.NotNil(t, emptyID) - checkVersion(emptyID) + checkVersion(t, rulesManager, emptyID) duplicateRule, err := rulesManager.AddRule(wrapper.Context, Rule{Name: "empty", Color: "#eee"}) assert.Error(t, err) @@ -72,7 +61,7 @@ func TestAddAndGetAllRules(t *testing.T) { rule1ID, err := rulesManager.AddRule(wrapper.Context, rule1) assert.NoError(t, err) assert.NotNil(t, rule1ID) - checkVersion(rule1ID) + checkVersion(t, rulesManager, rule1ID) rule2 := Rule{ Name: "rule2", @@ -86,7 +75,7 @@ func TestAddAndGetAllRules(t *testing.T) { rule2ID, err := rulesManager.AddRule(wrapper.Context, rule2) assert.NoError(t, err) assert.NotNil(t, rule2ID) - checkVersion(rule2ID) + checkVersion(t, rulesManager, rule2ID) rule3 := Rule{ Name: "rule3", @@ -100,7 +89,7 @@ func TestAddAndGetAllRules(t *testing.T) { rule3ID, err := rulesManager.AddRule(wrapper.Context, rule3) assert.NoError(t, err) assert.NotNil(t, rule3ID) - checkVersion(rule3ID) + checkVersion(t, rulesManager, rule3ID) checkRule := func(expected Rule, patternIDs []int) { var rule Rule @@ -109,7 +98,7 @@ func TestAddAndGetAllRules(t *testing.T) { require.NoError(t, err) for i, id := range patternIDs { - rule.Patterns[i].internalID = id + rule.Patterns[i].internalID = uint(id) } assert.Equal(t, expected, rule) assert.Equal(t, expected, rulesManager.rules[expected.ID]) @@ -199,3 +188,97 @@ func TestLoadAndUpdateRules(t *testing.T) { wrapper.Destroy(t) } + +func TestFillWithMatchedRules(t *testing.T) { + wrapper := NewTestStorageWrapper(t) + wrapper.AddCollection(Rules) + + rulesManager := NewRulesManager(wrapper.Storage).(*rulesManagerImpl) + err := rulesManager.SetFlag(wrapper.Context, "flag") + require.NoError(t, err) + checkVersion(t, rulesManager, rulesManager.rulesByName["flag"].ID) + + emptyRule, err := rulesManager.AddRule(wrapper.Context, Rule{Name: "empty", Color: "#fff"}) + require.NoError(t, err) + checkVersion(t, rulesManager, emptyRule) + + conn := &Connection{} + rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{}, map[uint][]PatternSlice{}) + assert.ElementsMatch(t, []RowID{emptyRule}, conn.MatchedRules) + + filterRule, err := rulesManager.AddRule(wrapper.Context, Rule{ + Name: "filter", + Color: "#fff", + Filter: Filter{ + ServicePort: 80, + ClientAddress: "10.10.10.10", + ClientPort: 60000, + MinDuration: 2000, + MaxDuration: 4000, + MinBytes: 64, + MaxBytes: 64, + }, + }) + require.NoError(t, err) + checkVersion(t, rulesManager, filterRule) + conn = &Connection{ + SourceIP: "10.10.10.10", + SourcePort: 60000, + DestinationPort: 80, + ClientBytes: 32, + ServerBytes: 32, + StartedAt: time.Now(), + ClosedAt: time.Now().Add(3 * time.Second), + } + rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{}, map[uint][]PatternSlice{}) + assert.ElementsMatch(t, []RowID{emptyRule, filterRule}, conn.MatchedRules) + + patternRule, err := rulesManager.AddRule(wrapper.Context, Rule{ + Name: "pattern", + Color: "#fff", + Patterns: []Pattern{ + {Regex: "pattern1", Direction: DirectionToClient, MinOccurrences: 1}, + {Regex: "pattern2", Direction: DirectionToServer, MaxOccurrences: 2}, + {Regex: "pattern3", Direction: DirectionBoth, MinOccurrences: 2, MaxOccurrences: 2}, + }, + }) + require.NoError(t, err) + checkVersion(t, rulesManager, patternRule) + conn = &Connection{} + rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0},{0, 0}}, 3: {{0, 0}}}, + map[uint][]PatternSlice{1: {{0, 0}}, 3: {{0, 0}}}) + assert.ElementsMatch(t, []RowID{emptyRule, patternRule}, conn.MatchedRules) + + rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0},{0, 0}}}, + map[uint][]PatternSlice{1: {{0, 0}}, 3: {{0, 0},{0, 0}}}) + assert.ElementsMatch(t, []RowID{emptyRule, patternRule}, conn.MatchedRules) + + rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0},{0, 0}}, 3: {{0, 0},{0, 0}}}, + map[uint][]PatternSlice{1: {{0, 0}}}) + assert.ElementsMatch(t, []RowID{emptyRule, patternRule}, conn.MatchedRules) + + rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0},{0, 0}}, 3: {{0, 0}}}, + map[uint][]PatternSlice{3: {{0, 0}}}) + assert.ElementsMatch(t, []RowID{emptyRule}, conn.MatchedRules) + + rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0},{0, 0},{0, 0}}, 3: {{0, 0}}}, + map[uint][]PatternSlice{1: {{0, 0}}, 3: {{0, 0}}}) + assert.ElementsMatch(t, []RowID{emptyRule}, conn.MatchedRules) + + rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0},{0, 0}}, 3: {{0, 0}}}, + map[uint][]PatternSlice{1: {{0, 0}}, 3: {{0, 0},{0, 0}}}) + assert.ElementsMatch(t, []RowID{emptyRule}, conn.MatchedRules) + + wrapper.Destroy(t) +} + +func checkVersion(t *testing.T, rulesManager RulesManager, id RowID) { + timeout := time.Tick(1 * time.Second) + + select { + case database := <-rulesManager.DatabaseUpdateChannel(): + assert.Equal(t, id, database.version) + case <-timeout: + t.Fatal("timeout") + } +} |