aboutsummaryrefslogtreecommitdiff
path: root/rules_manager_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'rules_manager_test.go')
-rw-r--r--rules_manager_test.go117
1 files changed, 100 insertions, 17 deletions
diff --git a/rules_manager_test.go b/rules_manager_test.go
index 68e3ae4..59ce11c 100644
--- a/rules_manager_test.go
+++ b/rules_manager_test.go
@@ -13,25 +13,14 @@ func TestAddAndGetAllRules(t *testing.T) {
rulesManager := NewRulesManager(wrapper.Storage).(*rulesManagerImpl)
- checkVersion := func(id RowID) {
- timeout := time.Tick(1 * time.Second)
-
- select {
- case database := <-rulesManager.DatabaseUpdateChannel():
- assert.Equal(t, id, database.version)
- case <-timeout:
- t.Fatal("timeout")
- }
- }
-
err := rulesManager.SetFlag(wrapper.Context, "FLAG{test}")
assert.NoError(t, err)
- checkVersion(rulesManager.rulesByName["flag"].ID)
+ checkVersion(t, rulesManager, rulesManager.rulesByName["flag"].ID)
emptyRule := Rule{Name: "empty", Color: "#fff", Enabled: true}
emptyID, err := rulesManager.AddRule(wrapper.Context, emptyRule)
assert.NoError(t, err)
assert.NotNil(t, emptyID)
- checkVersion(emptyID)
+ checkVersion(t, rulesManager, emptyID)
duplicateRule, err := rulesManager.AddRule(wrapper.Context, Rule{Name: "empty", Color: "#eee"})
assert.Error(t, err)
@@ -72,7 +61,7 @@ func TestAddAndGetAllRules(t *testing.T) {
rule1ID, err := rulesManager.AddRule(wrapper.Context, rule1)
assert.NoError(t, err)
assert.NotNil(t, rule1ID)
- checkVersion(rule1ID)
+ checkVersion(t, rulesManager, rule1ID)
rule2 := Rule{
Name: "rule2",
@@ -86,7 +75,7 @@ func TestAddAndGetAllRules(t *testing.T) {
rule2ID, err := rulesManager.AddRule(wrapper.Context, rule2)
assert.NoError(t, err)
assert.NotNil(t, rule2ID)
- checkVersion(rule2ID)
+ checkVersion(t, rulesManager, rule2ID)
rule3 := Rule{
Name: "rule3",
@@ -100,7 +89,7 @@ func TestAddAndGetAllRules(t *testing.T) {
rule3ID, err := rulesManager.AddRule(wrapper.Context, rule3)
assert.NoError(t, err)
assert.NotNil(t, rule3ID)
- checkVersion(rule3ID)
+ checkVersion(t, rulesManager, rule3ID)
checkRule := func(expected Rule, patternIDs []int) {
var rule Rule
@@ -109,7 +98,7 @@ func TestAddAndGetAllRules(t *testing.T) {
require.NoError(t, err)
for i, id := range patternIDs {
- rule.Patterns[i].internalID = id
+ rule.Patterns[i].internalID = uint(id)
}
assert.Equal(t, expected, rule)
assert.Equal(t, expected, rulesManager.rules[expected.ID])
@@ -199,3 +188,97 @@ func TestLoadAndUpdateRules(t *testing.T) {
wrapper.Destroy(t)
}
+
+func TestFillWithMatchedRules(t *testing.T) {
+ wrapper := NewTestStorageWrapper(t)
+ wrapper.AddCollection(Rules)
+
+ rulesManager := NewRulesManager(wrapper.Storage).(*rulesManagerImpl)
+ err := rulesManager.SetFlag(wrapper.Context, "flag")
+ require.NoError(t, err)
+ checkVersion(t, rulesManager, rulesManager.rulesByName["flag"].ID)
+
+ emptyRule, err := rulesManager.AddRule(wrapper.Context, Rule{Name: "empty", Color: "#fff"})
+ require.NoError(t, err)
+ checkVersion(t, rulesManager, emptyRule)
+
+ conn := &Connection{}
+ rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{}, map[uint][]PatternSlice{})
+ assert.ElementsMatch(t, []RowID{emptyRule}, conn.MatchedRules)
+
+ filterRule, err := rulesManager.AddRule(wrapper.Context, Rule{
+ Name: "filter",
+ Color: "#fff",
+ Filter: Filter{
+ ServicePort: 80,
+ ClientAddress: "10.10.10.10",
+ ClientPort: 60000,
+ MinDuration: 2000,
+ MaxDuration: 4000,
+ MinBytes: 64,
+ MaxBytes: 64,
+ },
+ })
+ require.NoError(t, err)
+ checkVersion(t, rulesManager, filterRule)
+ conn = &Connection{
+ SourceIP: "10.10.10.10",
+ SourcePort: 60000,
+ DestinationPort: 80,
+ ClientBytes: 32,
+ ServerBytes: 32,
+ StartedAt: time.Now(),
+ ClosedAt: time.Now().Add(3 * time.Second),
+ }
+ rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{}, map[uint][]PatternSlice{})
+ assert.ElementsMatch(t, []RowID{emptyRule, filterRule}, conn.MatchedRules)
+
+ patternRule, err := rulesManager.AddRule(wrapper.Context, Rule{
+ Name: "pattern",
+ Color: "#fff",
+ Patterns: []Pattern{
+ {Regex: "pattern1", Direction: DirectionToClient, MinOccurrences: 1},
+ {Regex: "pattern2", Direction: DirectionToServer, MaxOccurrences: 2},
+ {Regex: "pattern3", Direction: DirectionBoth, MinOccurrences: 2, MaxOccurrences: 2},
+ },
+ })
+ require.NoError(t, err)
+ checkVersion(t, rulesManager, patternRule)
+ conn = &Connection{}
+ rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0},{0, 0}}, 3: {{0, 0}}},
+ map[uint][]PatternSlice{1: {{0, 0}}, 3: {{0, 0}}})
+ assert.ElementsMatch(t, []RowID{emptyRule, patternRule}, conn.MatchedRules)
+
+ rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0},{0, 0}}},
+ map[uint][]PatternSlice{1: {{0, 0}}, 3: {{0, 0},{0, 0}}})
+ assert.ElementsMatch(t, []RowID{emptyRule, patternRule}, conn.MatchedRules)
+
+ rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0},{0, 0}}, 3: {{0, 0},{0, 0}}},
+ map[uint][]PatternSlice{1: {{0, 0}}})
+ assert.ElementsMatch(t, []RowID{emptyRule, patternRule}, conn.MatchedRules)
+
+ rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0},{0, 0}}, 3: {{0, 0}}},
+ map[uint][]PatternSlice{3: {{0, 0}}})
+ assert.ElementsMatch(t, []RowID{emptyRule}, conn.MatchedRules)
+
+ rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0},{0, 0},{0, 0}}, 3: {{0, 0}}},
+ map[uint][]PatternSlice{1: {{0, 0}}, 3: {{0, 0}}})
+ assert.ElementsMatch(t, []RowID{emptyRule}, conn.MatchedRules)
+
+ rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0},{0, 0}}, 3: {{0, 0}}},
+ map[uint][]PatternSlice{1: {{0, 0}}, 3: {{0, 0},{0, 0}}})
+ assert.ElementsMatch(t, []RowID{emptyRule}, conn.MatchedRules)
+
+ wrapper.Destroy(t)
+}
+
+func checkVersion(t *testing.T, rulesManager RulesManager, id RowID) {
+ timeout := time.Tick(1 * time.Second)
+
+ select {
+ case database := <-rulesManager.DatabaseUpdateChannel():
+ assert.Equal(t, id, database.version)
+ case <-timeout:
+ t.Fatal("timeout")
+ }
+}