diff options
-rw-r--r-- | connection_handler_test.go | 4 | ||||
-rw-r--r-- | routes.go | 6 | ||||
-rw-r--r-- | rules_manager.go | 21 | ||||
-rw-r--r-- | rules_manager_test.go | 170 |
4 files changed, 169 insertions, 32 deletions
diff --git a/connection_handler_test.go b/connection_handler_test.go index dd097e4..9fe41af 100644 --- a/connection_handler_test.go +++ b/connection_handler_test.go @@ -186,8 +186,8 @@ func (rm TestRulesManager) GetRule(_ RowID) (Rule, bool) { return Rule{}, false } -func (rm TestRulesManager) UpdateRule(_ context.Context, _ RowID, _ Rule) bool { - return false +func (rm TestRulesManager) UpdateRule(_ context.Context, _ RowID, _ Rule) (bool, error) { + return false, nil } func (rm TestRulesManager) GetRules() []Rule { @@ -57,8 +57,10 @@ func ApplicationRoutes(engine *gin.Engine, rulesManager RulesManager) { return } - updated := rulesManager.UpdateRule(c, id, rule) - if !updated { + updated, err := rulesManager.UpdateRule(c, id, rule) + if err != nil { + badRequest(c, err) + } else if !updated { notFound(c, UnorderedDocument{"id": id}) } else { success(c, rule) diff --git a/rules_manager.go b/rules_manager.go index 57f8768..89b6153 100644 --- a/rules_manager.go +++ b/rules_manager.go @@ -64,7 +64,7 @@ type RulesManager interface { LoadRules() error AddRule(context context.Context, rule Rule) (RowID, error) GetRule(id RowID) (Rule, bool) - UpdateRule(context context.Context, id RowID, rule Rule) bool + UpdateRule(context context.Context, id RowID, rule Rule) (bool, error) GetRules() []Rule FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice, serverMatches map[uint][]PatternSlice) DatabaseUpdateChannel() chan RulesDatabase @@ -138,10 +138,15 @@ func (rm *rulesManagerImpl) GetRule(id RowID) (Rule, bool) { return rule, isPresent } -func (rm *rulesManagerImpl) UpdateRule(context context.Context, id RowID, rule Rule) bool { +func (rm *rulesManagerImpl) UpdateRule(context context.Context, id RowID, rule Rule) (bool, error) { newRule, isPresent := rm.rules[id] if !isPresent { - return false + return false, nil + } + + sameName, isPresent := rm.rulesByName[rule.Name] + if isPresent && sameName.ID != id { + return false, errors.New("already exists another rule with the same name") } updated, err := rm.storage.Update(Rules).Context(context).Filter(OrderedDocument{{"_id", id}}). @@ -161,7 +166,7 @@ func (rm *rulesManagerImpl) UpdateRule(context context.Context, id RowID, rule R rm.mutex.Unlock() } - return updated + return updated, nil } func (rm *rulesManagerImpl) GetRules() []Rule { @@ -201,6 +206,7 @@ func (rm *rulesManagerImpl) validateAndAddRuleLocal(rule *Rule) error { } newPatterns := make([]*hyperscan.Pattern, 0, len(rule.Patterns)) + duplicatePatterns := make(map[string]bool) for i, pattern := range rule.Patterns { if err := rm.validate.Struct(pattern); err != nil { return err @@ -210,7 +216,11 @@ func (rm *rulesManagerImpl) validateAndAddRuleLocal(rule *Rule) error { if err != nil { return err } - if existingPattern, isPresent := rm.patternsIds[compiledPattern.String()]; isPresent { + regex := compiledPattern.String() + if _, isPresent := duplicatePatterns[regex]; isPresent { + return errors.New("duplicate pattern") + } + if existingPattern, isPresent := rm.patternsIds[regex]; isPresent { rule.Patterns[i].internalID = existingPattern continue } @@ -219,6 +229,7 @@ func (rm *rulesManagerImpl) validateAndAddRuleLocal(rule *Rule) error { rule.Patterns[i].internalID = id compiledPattern.Id = id newPatterns = append(newPatterns, compiledPattern) + duplicatePatterns[regex] = true } startId := len(rm.patterns) diff --git a/rules_manager_test.go b/rules_manager_test.go index 53d085d..68e3ae4 100644 --- a/rules_manager_test.go +++ b/rules_manager_test.go @@ -2,11 +2,12 @@ package main import ( "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" "time" ) -func TestAddRule(t *testing.T) { +func TestAddAndGetAllRules(t *testing.T) { wrapper := NewTestStorageWrapper(t) wrapper.AddCollection(Rules) @@ -16,7 +17,7 @@ func TestAddRule(t *testing.T) { timeout := time.Tick(1 * time.Second) select { - case database := <-rulesManager.databaseUpdated: + case database := <-rulesManager.DatabaseUpdateChannel(): assert.Equal(t, id, database.version) case <-timeout: t.Fatal("timeout") @@ -26,17 +27,47 @@ func TestAddRule(t *testing.T) { err := rulesManager.SetFlag(wrapper.Context, "FLAG{test}") assert.NoError(t, err) checkVersion(rulesManager.rulesByName["flag"].ID) - emptyID, err := rulesManager.AddRule(wrapper.Context, Rule{Name: "empty", Color: "#fff"}) + emptyRule := Rule{Name: "empty", Color: "#fff", Enabled: true} + emptyID, err := rulesManager.AddRule(wrapper.Context, emptyRule) assert.NoError(t, err) assert.NotNil(t, emptyID) checkVersion(emptyID) + duplicateRule, err := rulesManager.AddRule(wrapper.Context, Rule{Name: "empty", Color: "#eee"}) + assert.Error(t, err) + assert.Zero(t, duplicateRule) + + invalidPattern, err := rulesManager.AddRule(wrapper.Context, Rule{ + Name: "invalidPattern", + Color: "#eee", + Patterns: []Pattern{ + { + Regex: "invalid)", + }, + }, + }) + assert.Error(t, err) + assert.Zero(t, invalidPattern) + rule1 := Rule{ Name: "rule1", Color: "#eee", Patterns: []Pattern{ - {Regex: "nope", Flags: RegexFlags{Caseless: true}}, - }, + { + Regex: "pattern1", + Flags: RegexFlags{ + Caseless: true, + DotAll: true, + MultiLine: true, + SingleMatch: true, + Utf8Mode: true, + UnicodeProperty: true, + }, + MinOccurrences: 1, + MaxOccurrences: 3, + Direction: DirectionBoth, + }}, + Enabled: true, } rule1ID, err := rulesManager.AddRule(wrapper.Context, rule1) assert.NoError(t, err) @@ -47,31 +78,124 @@ func TestAddRule(t *testing.T) { Name: "rule2", Color: "#ddd", Patterns: []Pattern{ - {Regex: "nope", Flags: RegexFlags{Caseless: true}}, - {Regex: "yep"}, + {Regex: "pattern1"}, + {Regex: "pattern2"}, }, + Enabled: true, } rule2ID, err := rulesManager.AddRule(wrapper.Context, rule2) assert.NoError(t, err) assert.NotNil(t, rule2ID) checkVersion(rule2ID) - assert.Len(t, rulesManager.rules, 4) - assert.Len(t, rulesManager.rulesByName, 4) - assert.Len(t, rulesManager.patterns, 3) - assert.Len(t, rulesManager.patternsIds, 3) - assert.Equal(t, emptyID, rulesManager.rules[emptyID].ID) - assert.Equal(t, emptyID, rulesManager.rulesByName["empty"].ID) - assert.Len(t, rulesManager.rules[emptyID].Patterns, 0) - assert.Equal(t, rule1ID, rulesManager.rules[rule1ID].ID) - assert.Equal(t, rule1ID, rulesManager.rulesByName[rule1.Name].ID) - assert.Len(t, rulesManager.rules[rule1ID].Patterns, 1) - assert.Equal(t, 1, rulesManager.rules[rule1ID].Patterns[0].internalID) - assert.Equal(t, rule2ID, rulesManager.rules[rule2ID].ID) - assert.Equal(t, rule2ID, rulesManager.rulesByName[rule2.Name].ID) - assert.Len(t, rulesManager.rules[rule2ID].Patterns, 2) - assert.Equal(t, 1, rulesManager.rules[rule2ID].Patterns[0].internalID) - assert.Equal(t, 2, rulesManager.rules[rule2ID].Patterns[1].internalID) + rule3 := Rule{ + Name: "rule3", + Color: "#ccc", + Patterns: []Pattern{ + {Regex: "pattern2"}, + {Regex: "pattern3"}, + }, + Enabled: true, + } + rule3ID, err := rulesManager.AddRule(wrapper.Context, rule3) + assert.NoError(t, err) + assert.NotNil(t, rule3ID) + checkVersion(rule3ID) + + checkRule := func(expected Rule, patternIDs []int) { + var rule Rule + err := wrapper.Storage.Find(Rules).Context(wrapper.Context). + Filter(OrderedDocument{{"_id", expected.ID}}).First(&rule) + require.NoError(t, err) + + for i, id := range patternIDs { + rule.Patterns[i].internalID = id + } + assert.Equal(t, expected, rule) + assert.Equal(t, expected, rulesManager.rules[expected.ID]) + assert.Equal(t, expected, rulesManager.rulesByName[expected.Name]) + } + + assert.Len(t, rulesManager.rules, 5) + assert.Len(t, rulesManager.rulesByName, 5) + assert.Len(t, rulesManager.patterns, 5) + assert.Len(t, rulesManager.patternsIds, 5) + + emptyRule.ID = emptyID + rule1.ID = rule1ID + rule2.ID = rule2ID + rule3.ID = rule3ID + + checkRule(emptyRule, []int{}) + checkRule(rule1, []int{1}) + checkRule(rule2, []int{2, 3}) + checkRule(rule3, []int{3, 4}) + + assert.Len(t, rulesManager.GetRules(), 5) + assert.ElementsMatch(t, []Rule{rulesManager.rulesByName["flag"], emptyRule, rule1, rule2, rule3}, rulesManager.GetRules()) + + wrapper.Destroy(t) +} + +func TestLoadAndUpdateRules(t *testing.T) { + wrapper := NewTestStorageWrapper(t) + wrapper.AddCollection(Rules) + + expectedIds := []RowID{NewRowID(), NewRowID(), NewRowID(), NewRowID()} + rules := []interface{}{ + Rule{ID: expectedIds[0], Name: "rule1", Color: "#fff", Patterns: []Pattern{ + {Regex: "pattern1", Flags: RegexFlags{Caseless: true}, Direction: DirectionToClient, internalID: 0}, + }}, + Rule{ID: expectedIds[1], Name: "rule2", Color: "#eee", Patterns: []Pattern{ + {Regex: "pattern2", MinOccurrences: 1, MaxOccurrences: 3, Direction: DirectionToServer, internalID: 1}, + }}, + Rule{ID: expectedIds[2], Name: "rule3", Color: "#ddd", Patterns: []Pattern{ + {Regex: "pattern2", Direction: DirectionBoth, internalID: 1}, + {Regex: "pattern3", Flags: RegexFlags{MultiLine: true}, internalID: 2}, + }}, + Rule{ID: expectedIds[3], Name: "rule4", Color: "#ccc", Patterns: []Pattern{ + {Regex: "pattern3", internalID: 3}, + }}, + } + ids, err := wrapper.Storage.Insert(Rules).Context(wrapper.Context).Many(rules) + require.NoError(t, err) + assert.ElementsMatch(t, expectedIds, ids) + + rulesManager := NewRulesManager(wrapper.Storage).(*rulesManagerImpl) + err = rulesManager.LoadRules() + require.NoError(t, err) + + rule, isPresent := rulesManager.GetRule(NewRowID()) + assert.Zero(t, rule) + assert.False(t, isPresent) + + for _, objRule := range rules { + expected := objRule.(Rule) + rule, isPresent := rulesManager.GetRule(expected.ID) + assert.True(t, isPresent) + assert.Equal(t, expected, rule) + } + + updated, err := rulesManager.UpdateRule(wrapper.Context, NewRowID(), Rule{}) + assert.False(t, updated) + assert.NoError(t, err) + + updated, err = rulesManager.UpdateRule(wrapper.Context, expectedIds[0], Rule{Name: "rule2", Color: "#fff"}) + assert.False(t, updated) + assert.Error(t, err) + + for _, objRule := range rules { + expected := objRule.(Rule) + expected.Name = expected.ID.Hex() + expected.Color = "#000" + updated, err := rulesManager.UpdateRule(wrapper.Context, expected.ID, expected) + assert.True(t, updated) + assert.NoError(t, err) + + rule, isPresent := rulesManager.GetRule(expected.ID) + assert.True(t, isPresent) + assert.Equal(t, expected, rule) + } wrapper.Destroy(t) } |