aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--connection_handler_test.go4
-rw-r--r--routes.go6
-rw-r--r--rules_manager.go21
-rw-r--r--rules_manager_test.go170
4 files changed, 169 insertions, 32 deletions
diff --git a/connection_handler_test.go b/connection_handler_test.go
index dd097e4..9fe41af 100644
--- a/connection_handler_test.go
+++ b/connection_handler_test.go
@@ -186,8 +186,8 @@ func (rm TestRulesManager) GetRule(_ RowID) (Rule, bool) {
return Rule{}, false
}
-func (rm TestRulesManager) UpdateRule(_ context.Context, _ RowID, _ Rule) bool {
- return false
+func (rm TestRulesManager) UpdateRule(_ context.Context, _ RowID, _ Rule) (bool, error) {
+ return false, nil
}
func (rm TestRulesManager) GetRules() []Rule {
diff --git a/routes.go b/routes.go
index 111dd78..4599b8f 100644
--- a/routes.go
+++ b/routes.go
@@ -57,8 +57,10 @@ func ApplicationRoutes(engine *gin.Engine, rulesManager RulesManager) {
return
}
- updated := rulesManager.UpdateRule(c, id, rule)
- if !updated {
+ updated, err := rulesManager.UpdateRule(c, id, rule)
+ if err != nil {
+ badRequest(c, err)
+ } else if !updated {
notFound(c, UnorderedDocument{"id": id})
} else {
success(c, rule)
diff --git a/rules_manager.go b/rules_manager.go
index 57f8768..89b6153 100644
--- a/rules_manager.go
+++ b/rules_manager.go
@@ -64,7 +64,7 @@ type RulesManager interface {
LoadRules() error
AddRule(context context.Context, rule Rule) (RowID, error)
GetRule(id RowID) (Rule, bool)
- UpdateRule(context context.Context, id RowID, rule Rule) bool
+ UpdateRule(context context.Context, id RowID, rule Rule) (bool, error)
GetRules() []Rule
FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice, serverMatches map[uint][]PatternSlice)
DatabaseUpdateChannel() chan RulesDatabase
@@ -138,10 +138,15 @@ func (rm *rulesManagerImpl) GetRule(id RowID) (Rule, bool) {
return rule, isPresent
}
-func (rm *rulesManagerImpl) UpdateRule(context context.Context, id RowID, rule Rule) bool {
+func (rm *rulesManagerImpl) UpdateRule(context context.Context, id RowID, rule Rule) (bool, error) {
newRule, isPresent := rm.rules[id]
if !isPresent {
- return false
+ return false, nil
+ }
+
+ sameName, isPresent := rm.rulesByName[rule.Name]
+ if isPresent && sameName.ID != id {
+ return false, errors.New("already exists another rule with the same name")
}
updated, err := rm.storage.Update(Rules).Context(context).Filter(OrderedDocument{{"_id", id}}).
@@ -161,7 +166,7 @@ func (rm *rulesManagerImpl) UpdateRule(context context.Context, id RowID, rule R
rm.mutex.Unlock()
}
- return updated
+ return updated, nil
}
func (rm *rulesManagerImpl) GetRules() []Rule {
@@ -201,6 +206,7 @@ func (rm *rulesManagerImpl) validateAndAddRuleLocal(rule *Rule) error {
}
newPatterns := make([]*hyperscan.Pattern, 0, len(rule.Patterns))
+ duplicatePatterns := make(map[string]bool)
for i, pattern := range rule.Patterns {
if err := rm.validate.Struct(pattern); err != nil {
return err
@@ -210,7 +216,11 @@ func (rm *rulesManagerImpl) validateAndAddRuleLocal(rule *Rule) error {
if err != nil {
return err
}
- if existingPattern, isPresent := rm.patternsIds[compiledPattern.String()]; isPresent {
+ regex := compiledPattern.String()
+ if _, isPresent := duplicatePatterns[regex]; isPresent {
+ return errors.New("duplicate pattern")
+ }
+ if existingPattern, isPresent := rm.patternsIds[regex]; isPresent {
rule.Patterns[i].internalID = existingPattern
continue
}
@@ -219,6 +229,7 @@ func (rm *rulesManagerImpl) validateAndAddRuleLocal(rule *Rule) error {
rule.Patterns[i].internalID = id
compiledPattern.Id = id
newPatterns = append(newPatterns, compiledPattern)
+ duplicatePatterns[regex] = true
}
startId := len(rm.patterns)
diff --git a/rules_manager_test.go b/rules_manager_test.go
index 53d085d..68e3ae4 100644
--- a/rules_manager_test.go
+++ b/rules_manager_test.go
@@ -2,11 +2,12 @@ package main
import (
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
"testing"
"time"
)
-func TestAddRule(t *testing.T) {
+func TestAddAndGetAllRules(t *testing.T) {
wrapper := NewTestStorageWrapper(t)
wrapper.AddCollection(Rules)
@@ -16,7 +17,7 @@ func TestAddRule(t *testing.T) {
timeout := time.Tick(1 * time.Second)
select {
- case database := <-rulesManager.databaseUpdated:
+ case database := <-rulesManager.DatabaseUpdateChannel():
assert.Equal(t, id, database.version)
case <-timeout:
t.Fatal("timeout")
@@ -26,17 +27,47 @@ func TestAddRule(t *testing.T) {
err := rulesManager.SetFlag(wrapper.Context, "FLAG{test}")
assert.NoError(t, err)
checkVersion(rulesManager.rulesByName["flag"].ID)
- emptyID, err := rulesManager.AddRule(wrapper.Context, Rule{Name: "empty", Color: "#fff"})
+ emptyRule := Rule{Name: "empty", Color: "#fff", Enabled: true}
+ emptyID, err := rulesManager.AddRule(wrapper.Context, emptyRule)
assert.NoError(t, err)
assert.NotNil(t, emptyID)
checkVersion(emptyID)
+ duplicateRule, err := rulesManager.AddRule(wrapper.Context, Rule{Name: "empty", Color: "#eee"})
+ assert.Error(t, err)
+ assert.Zero(t, duplicateRule)
+
+ invalidPattern, err := rulesManager.AddRule(wrapper.Context, Rule{
+ Name: "invalidPattern",
+ Color: "#eee",
+ Patterns: []Pattern{
+ {
+ Regex: "invalid)",
+ },
+ },
+ })
+ assert.Error(t, err)
+ assert.Zero(t, invalidPattern)
+
rule1 := Rule{
Name: "rule1",
Color: "#eee",
Patterns: []Pattern{
- {Regex: "nope", Flags: RegexFlags{Caseless: true}},
- },
+ {
+ Regex: "pattern1",
+ Flags: RegexFlags{
+ Caseless: true,
+ DotAll: true,
+ MultiLine: true,
+ SingleMatch: true,
+ Utf8Mode: true,
+ UnicodeProperty: true,
+ },
+ MinOccurrences: 1,
+ MaxOccurrences: 3,
+ Direction: DirectionBoth,
+ }},
+ Enabled: true,
}
rule1ID, err := rulesManager.AddRule(wrapper.Context, rule1)
assert.NoError(t, err)
@@ -47,31 +78,124 @@ func TestAddRule(t *testing.T) {
Name: "rule2",
Color: "#ddd",
Patterns: []Pattern{
- {Regex: "nope", Flags: RegexFlags{Caseless: true}},
- {Regex: "yep"},
+ {Regex: "pattern1"},
+ {Regex: "pattern2"},
},
+ Enabled: true,
}
rule2ID, err := rulesManager.AddRule(wrapper.Context, rule2)
assert.NoError(t, err)
assert.NotNil(t, rule2ID)
checkVersion(rule2ID)
- assert.Len(t, rulesManager.rules, 4)
- assert.Len(t, rulesManager.rulesByName, 4)
- assert.Len(t, rulesManager.patterns, 3)
- assert.Len(t, rulesManager.patternsIds, 3)
- assert.Equal(t, emptyID, rulesManager.rules[emptyID].ID)
- assert.Equal(t, emptyID, rulesManager.rulesByName["empty"].ID)
- assert.Len(t, rulesManager.rules[emptyID].Patterns, 0)
- assert.Equal(t, rule1ID, rulesManager.rules[rule1ID].ID)
- assert.Equal(t, rule1ID, rulesManager.rulesByName[rule1.Name].ID)
- assert.Len(t, rulesManager.rules[rule1ID].Patterns, 1)
- assert.Equal(t, 1, rulesManager.rules[rule1ID].Patterns[0].internalID)
- assert.Equal(t, rule2ID, rulesManager.rules[rule2ID].ID)
- assert.Equal(t, rule2ID, rulesManager.rulesByName[rule2.Name].ID)
- assert.Len(t, rulesManager.rules[rule2ID].Patterns, 2)
- assert.Equal(t, 1, rulesManager.rules[rule2ID].Patterns[0].internalID)
- assert.Equal(t, 2, rulesManager.rules[rule2ID].Patterns[1].internalID)
+ rule3 := Rule{
+ Name: "rule3",
+ Color: "#ccc",
+ Patterns: []Pattern{
+ {Regex: "pattern2"},
+ {Regex: "pattern3"},
+ },
+ Enabled: true,
+ }
+ rule3ID, err := rulesManager.AddRule(wrapper.Context, rule3)
+ assert.NoError(t, err)
+ assert.NotNil(t, rule3ID)
+ checkVersion(rule3ID)
+
+ checkRule := func(expected Rule, patternIDs []int) {
+ var rule Rule
+ err := wrapper.Storage.Find(Rules).Context(wrapper.Context).
+ Filter(OrderedDocument{{"_id", expected.ID}}).First(&rule)
+ require.NoError(t, err)
+
+ for i, id := range patternIDs {
+ rule.Patterns[i].internalID = id
+ }
+ assert.Equal(t, expected, rule)
+ assert.Equal(t, expected, rulesManager.rules[expected.ID])
+ assert.Equal(t, expected, rulesManager.rulesByName[expected.Name])
+ }
+
+ assert.Len(t, rulesManager.rules, 5)
+ assert.Len(t, rulesManager.rulesByName, 5)
+ assert.Len(t, rulesManager.patterns, 5)
+ assert.Len(t, rulesManager.patternsIds, 5)
+
+ emptyRule.ID = emptyID
+ rule1.ID = rule1ID
+ rule2.ID = rule2ID
+ rule3.ID = rule3ID
+
+ checkRule(emptyRule, []int{})
+ checkRule(rule1, []int{1})
+ checkRule(rule2, []int{2, 3})
+ checkRule(rule3, []int{3, 4})
+
+ assert.Len(t, rulesManager.GetRules(), 5)
+ assert.ElementsMatch(t, []Rule{rulesManager.rulesByName["flag"], emptyRule, rule1, rule2, rule3}, rulesManager.GetRules())
+
+ wrapper.Destroy(t)
+}
+
+func TestLoadAndUpdateRules(t *testing.T) {
+ wrapper := NewTestStorageWrapper(t)
+ wrapper.AddCollection(Rules)
+
+ expectedIds := []RowID{NewRowID(), NewRowID(), NewRowID(), NewRowID()}
+ rules := []interface{}{
+ Rule{ID: expectedIds[0], Name: "rule1", Color: "#fff", Patterns: []Pattern{
+ {Regex: "pattern1", Flags: RegexFlags{Caseless: true}, Direction: DirectionToClient, internalID: 0},
+ }},
+ Rule{ID: expectedIds[1], Name: "rule2", Color: "#eee", Patterns: []Pattern{
+ {Regex: "pattern2", MinOccurrences: 1, MaxOccurrences: 3, Direction: DirectionToServer, internalID: 1},
+ }},
+ Rule{ID: expectedIds[2], Name: "rule3", Color: "#ddd", Patterns: []Pattern{
+ {Regex: "pattern2", Direction: DirectionBoth, internalID: 1},
+ {Regex: "pattern3", Flags: RegexFlags{MultiLine: true}, internalID: 2},
+ }},
+ Rule{ID: expectedIds[3], Name: "rule4", Color: "#ccc", Patterns: []Pattern{
+ {Regex: "pattern3", internalID: 3},
+ }},
+ }
+ ids, err := wrapper.Storage.Insert(Rules).Context(wrapper.Context).Many(rules)
+ require.NoError(t, err)
+ assert.ElementsMatch(t, expectedIds, ids)
+
+ rulesManager := NewRulesManager(wrapper.Storage).(*rulesManagerImpl)
+ err = rulesManager.LoadRules()
+ require.NoError(t, err)
+
+ rule, isPresent := rulesManager.GetRule(NewRowID())
+ assert.Zero(t, rule)
+ assert.False(t, isPresent)
+
+ for _, objRule := range rules {
+ expected := objRule.(Rule)
+ rule, isPresent := rulesManager.GetRule(expected.ID)
+ assert.True(t, isPresent)
+ assert.Equal(t, expected, rule)
+ }
+
+ updated, err := rulesManager.UpdateRule(wrapper.Context, NewRowID(), Rule{})
+ assert.False(t, updated)
+ assert.NoError(t, err)
+
+ updated, err = rulesManager.UpdateRule(wrapper.Context, expectedIds[0], Rule{Name: "rule2", Color: "#fff"})
+ assert.False(t, updated)
+ assert.Error(t, err)
+
+ for _, objRule := range rules {
+ expected := objRule.(Rule)
+ expected.Name = expected.ID.Hex()
+ expected.Color = "#000"
+ updated, err := rulesManager.UpdateRule(wrapper.Context, expected.ID, expected)
+ assert.True(t, updated)
+ assert.NoError(t, err)
+
+ rule, isPresent := rulesManager.GetRule(expected.ID)
+ assert.True(t, isPresent)
+ assert.Equal(t, expected, rule)
+ }
wrapper.Destroy(t)
}