diff options
Diffstat (limited to 'rules_manager_test.go')
-rw-r--r-- | rules_manager_test.go | 117 |
1 files changed, 100 insertions, 17 deletions
diff --git a/rules_manager_test.go b/rules_manager_test.go index 68e3ae4..59ce11c 100644 --- a/rules_manager_test.go +++ b/rules_manager_test.go @@ -13,25 +13,14 @@ func TestAddAndGetAllRules(t *testing.T) { rulesManager := NewRulesManager(wrapper.Storage).(*rulesManagerImpl) - checkVersion := func(id RowID) { - timeout := time.Tick(1 * time.Second) - - select { - case database := <-rulesManager.DatabaseUpdateChannel(): - assert.Equal(t, id, database.version) - case <-timeout: - t.Fatal("timeout") - } - } - err := rulesManager.SetFlag(wrapper.Context, "FLAG{test}") assert.NoError(t, err) - checkVersion(rulesManager.rulesByName["flag"].ID) + checkVersion(t, rulesManager, rulesManager.rulesByName["flag"].ID) emptyRule := Rule{Name: "empty", Color: "#fff", Enabled: true} emptyID, err := rulesManager.AddRule(wrapper.Context, emptyRule) assert.NoError(t, err) assert.NotNil(t, emptyID) - checkVersion(emptyID) + checkVersion(t, rulesManager, emptyID) duplicateRule, err := rulesManager.AddRule(wrapper.Context, Rule{Name: "empty", Color: "#eee"}) assert.Error(t, err) @@ -72,7 +61,7 @@ func TestAddAndGetAllRules(t *testing.T) { rule1ID, err := rulesManager.AddRule(wrapper.Context, rule1) assert.NoError(t, err) assert.NotNil(t, rule1ID) - checkVersion(rule1ID) + checkVersion(t, rulesManager, rule1ID) rule2 := Rule{ Name: "rule2", @@ -86,7 +75,7 @@ func TestAddAndGetAllRules(t *testing.T) { rule2ID, err := rulesManager.AddRule(wrapper.Context, rule2) assert.NoError(t, err) assert.NotNil(t, rule2ID) - checkVersion(rule2ID) + checkVersion(t, rulesManager, rule2ID) rule3 := Rule{ Name: "rule3", @@ -100,7 +89,7 @@ func TestAddAndGetAllRules(t *testing.T) { rule3ID, err := rulesManager.AddRule(wrapper.Context, rule3) assert.NoError(t, err) assert.NotNil(t, rule3ID) - checkVersion(rule3ID) + checkVersion(t, rulesManager, rule3ID) checkRule := func(expected Rule, patternIDs []int) { var rule Rule @@ -109,7 +98,7 @@ func TestAddAndGetAllRules(t *testing.T) { require.NoError(t, err) for i, id := range patternIDs { - rule.Patterns[i].internalID = id + rule.Patterns[i].internalID = uint(id) } assert.Equal(t, expected, rule) assert.Equal(t, expected, rulesManager.rules[expected.ID]) @@ -199,3 +188,97 @@ func TestLoadAndUpdateRules(t *testing.T) { wrapper.Destroy(t) } + +func TestFillWithMatchedRules(t *testing.T) { + wrapper := NewTestStorageWrapper(t) + wrapper.AddCollection(Rules) + + rulesManager := NewRulesManager(wrapper.Storage).(*rulesManagerImpl) + err := rulesManager.SetFlag(wrapper.Context, "flag") + require.NoError(t, err) + checkVersion(t, rulesManager, rulesManager.rulesByName["flag"].ID) + + emptyRule, err := rulesManager.AddRule(wrapper.Context, Rule{Name: "empty", Color: "#fff"}) + require.NoError(t, err) + checkVersion(t, rulesManager, emptyRule) + + conn := &Connection{} + rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{}, map[uint][]PatternSlice{}) + assert.ElementsMatch(t, []RowID{emptyRule}, conn.MatchedRules) + + filterRule, err := rulesManager.AddRule(wrapper.Context, Rule{ + Name: "filter", + Color: "#fff", + Filter: Filter{ + ServicePort: 80, + ClientAddress: "10.10.10.10", + ClientPort: 60000, + MinDuration: 2000, + MaxDuration: 4000, + MinBytes: 64, + MaxBytes: 64, + }, + }) + require.NoError(t, err) + checkVersion(t, rulesManager, filterRule) + conn = &Connection{ + SourceIP: "10.10.10.10", + SourcePort: 60000, + DestinationPort: 80, + ClientBytes: 32, + ServerBytes: 32, + StartedAt: time.Now(), + ClosedAt: time.Now().Add(3 * time.Second), + } + rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{}, map[uint][]PatternSlice{}) + assert.ElementsMatch(t, []RowID{emptyRule, filterRule}, conn.MatchedRules) + + patternRule, err := rulesManager.AddRule(wrapper.Context, Rule{ + Name: "pattern", + Color: "#fff", + Patterns: []Pattern{ + {Regex: "pattern1", Direction: DirectionToClient, MinOccurrences: 1}, + {Regex: "pattern2", Direction: DirectionToServer, MaxOccurrences: 2}, + {Regex: "pattern3", Direction: DirectionBoth, MinOccurrences: 2, MaxOccurrences: 2}, + }, + }) + require.NoError(t, err) + checkVersion(t, rulesManager, patternRule) + conn = &Connection{} + rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0},{0, 0}}, 3: {{0, 0}}}, + map[uint][]PatternSlice{1: {{0, 0}}, 3: {{0, 0}}}) + assert.ElementsMatch(t, []RowID{emptyRule, patternRule}, conn.MatchedRules) + + rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0},{0, 0}}}, + map[uint][]PatternSlice{1: {{0, 0}}, 3: {{0, 0},{0, 0}}}) + assert.ElementsMatch(t, []RowID{emptyRule, patternRule}, conn.MatchedRules) + + rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0},{0, 0}}, 3: {{0, 0},{0, 0}}}, + map[uint][]PatternSlice{1: {{0, 0}}}) + assert.ElementsMatch(t, []RowID{emptyRule, patternRule}, conn.MatchedRules) + + rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0},{0, 0}}, 3: {{0, 0}}}, + map[uint][]PatternSlice{3: {{0, 0}}}) + assert.ElementsMatch(t, []RowID{emptyRule}, conn.MatchedRules) + + rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0},{0, 0},{0, 0}}, 3: {{0, 0}}}, + map[uint][]PatternSlice{1: {{0, 0}}, 3: {{0, 0}}}) + assert.ElementsMatch(t, []RowID{emptyRule}, conn.MatchedRules) + + rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0},{0, 0}}, 3: {{0, 0}}}, + map[uint][]PatternSlice{1: {{0, 0}}, 3: {{0, 0},{0, 0}}}) + assert.ElementsMatch(t, []RowID{emptyRule}, conn.MatchedRules) + + wrapper.Destroy(t) +} + +func checkVersion(t *testing.T, rulesManager RulesManager, id RowID) { + timeout := time.Tick(1 * time.Second) + + select { + case database := <-rulesManager.DatabaseUpdateChannel(): + assert.Equal(t, id, database.version) + case <-timeout: + t.Fatal("timeout") + } +} |