diff options
-rw-r--r-- | application_router_test.go | 2 | ||||
-rw-r--r-- | rules_manager.go | 46 | ||||
-rw-r--r-- | rules_manager_test.go | 15 |
3 files changed, 33 insertions, 30 deletions
diff --git a/application_router_test.go b/application_router_test.go index ec6d151..27c3651 100644 --- a/application_router_test.go +++ b/application_router_test.go @@ -109,7 +109,7 @@ func TestRulesApi(t *testing.T) { var rules []Rule assert.Equal(t, http.StatusOK, w.Code) assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &rules)) - assert.Len(t, rules, 3) + assert.Len(t, rules, 4) toolkit.wrapper.Destroy(t) } diff --git a/rules_manager.go b/rules_manager.go index 0e6c3d1..327e4ec 100644 --- a/rules_manager.go +++ b/rules_manager.go @@ -121,24 +121,22 @@ func LoadRulesManager(storage Storage, flagRegex string) (RulesManager, error) { // if there are no rules in database (e.g. first run), set flagRegex as first rule if len(rulesManager.rules) == 0 { - go func() { - _, _ = rulesManager.AddRule(context.Background(), Rule{ - Name: "flag_out", - Color: "#e53935", - Notes: "Mark connections where the flags are stolen", - Patterns: []Pattern{ - {Regex: flagRegex, Direction: DirectionToClient, Flags: RegexFlags{Utf8Mode: true}}, - }, - }) - _, _ = rulesManager.AddRule(context.Background(), Rule{ - Name: "flag_in", - Color: "#43A047", - Notes: "Mark connections where the flags are placed", - Patterns: []Pattern{ - {Regex: flagRegex, Direction: DirectionToServer, Flags: RegexFlags{Utf8Mode: true}}, - }, - }) - }() + _, _ = rulesManager.AddRule(context.Background(), Rule{ + Name: "flag_out", + Color: "#e53935", + Notes: "Mark connections where the flags are stolen", + Patterns: []Pattern{ + {Regex: flagRegex, Direction: DirectionToClient, Flags: RegexFlags{Utf8Mode: true}}, + }, + }) + _, _ = rulesManager.AddRule(context.Background(), Rule{ + Name: "flag_in", + Color: "#43A047", + Notes: "Mark connections where the flags are placed", + Patterns: []Pattern{ + {Regex: flagRegex, Direction: DirectionToServer, Flags: RegexFlags{Utf8Mode: true}}, + }, + }) } else { if err := rulesManager.generateDatabase(rules[len(rules)-1].ID); err != nil { return nil, err @@ -348,11 +346,13 @@ func (rm *rulesManagerImpl) generateDatabase(version RowID) error { return err } - rm.databaseUpdated <- RulesDatabase{ - database: database, - databaseSize: len(rm.patterns), - version: version, - } + go func() { + rm.databaseUpdated <- RulesDatabase{ + database: database, + databaseSize: len(rm.patterns), + version: version, + } + }() return nil } diff --git a/rules_manager_test.go b/rules_manager_test.go index dded096..215e601 100644 --- a/rules_manager_test.go +++ b/rules_manager_test.go @@ -31,7 +31,8 @@ func TestAddAndGetAllRules(t *testing.T) { rulesManager, err := LoadRulesManager(wrapper.Storage, "FLAG{test}") require.NoError(t, err) impl := rulesManager.(*rulesManagerImpl) - checkVersion(t, rulesManager, impl.rulesByName["flag"].ID) + checkVersion(t, rulesManager, impl.rulesByName["flag_out"].ID) + checkVersion(t, rulesManager, impl.rulesByName["flag_in"].ID) emptyRule := Rule{Name: "empty", Color: "#fff", Enabled: true} emptyID, err := rulesManager.AddRule(wrapper.Context, emptyRule) assert.NoError(t, err) @@ -120,8 +121,8 @@ func TestAddAndGetAllRules(t *testing.T) { assert.Equal(t, expected, impl.rulesByName[expected.Name]) } - assert.Len(t, impl.rules, 5) - assert.Len(t, impl.rulesByName, 5) + assert.Len(t, impl.rules, 6) + assert.Len(t, impl.rulesByName, 6) assert.Len(t, impl.patterns, 5) assert.Len(t, impl.patternsIds, 5) @@ -135,8 +136,9 @@ func TestAddAndGetAllRules(t *testing.T) { checkRule(rule2, []int{2, 3}) checkRule(rule3, []int{3, 4}) - assert.Len(t, rulesManager.GetRules(), 5) - assert.ElementsMatch(t, []Rule{impl.rulesByName["flag"], emptyRule, rule1, rule2, rule3}, rulesManager.GetRules()) + assert.Len(t, rulesManager.GetRules(), 6) + assert.ElementsMatch(t, []Rule{impl.rulesByName["flag_out"], impl.rulesByName["flag_in"], emptyRule, + rule1, rule2, rule3}, rulesManager.GetRules()) wrapper.Destroy(t) } @@ -210,7 +212,8 @@ func TestFillWithMatchedRules(t *testing.T) { rulesManager, err := LoadRulesManager(wrapper.Storage, "FLAG{test}") require.NoError(t, err) impl := rulesManager.(*rulesManagerImpl) - checkVersion(t, rulesManager, impl.rulesByName["flag"].ID) + checkVersion(t, rulesManager, impl.rulesByName["flag_out"].ID) + checkVersion(t, rulesManager, impl.rulesByName["flag_in"].ID) emptyRule, err := rulesManager.AddRule(wrapper.Context, Rule{Name: "empty", Color: "#fff"}) require.NoError(t, err) |