aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--application_router_test.go2
-rw-r--r--rules_manager.go46
-rw-r--r--rules_manager_test.go15
3 files changed, 33 insertions, 30 deletions
diff --git a/application_router_test.go b/application_router_test.go
index ec6d151..27c3651 100644
--- a/application_router_test.go
+++ b/application_router_test.go
@@ -109,7 +109,7 @@ func TestRulesApi(t *testing.T) {
var rules []Rule
assert.Equal(t, http.StatusOK, w.Code)
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &rules))
- assert.Len(t, rules, 3)
+ assert.Len(t, rules, 4)
toolkit.wrapper.Destroy(t)
}
diff --git a/rules_manager.go b/rules_manager.go
index 0e6c3d1..327e4ec 100644
--- a/rules_manager.go
+++ b/rules_manager.go
@@ -121,24 +121,22 @@ func LoadRulesManager(storage Storage, flagRegex string) (RulesManager, error) {
// if there are no rules in database (e.g. first run), set flagRegex as first rule
if len(rulesManager.rules) == 0 {
- go func() {
- _, _ = rulesManager.AddRule(context.Background(), Rule{
- Name: "flag_out",
- Color: "#e53935",
- Notes: "Mark connections where the flags are stolen",
- Patterns: []Pattern{
- {Regex: flagRegex, Direction: DirectionToClient, Flags: RegexFlags{Utf8Mode: true}},
- },
- })
- _, _ = rulesManager.AddRule(context.Background(), Rule{
- Name: "flag_in",
- Color: "#43A047",
- Notes: "Mark connections where the flags are placed",
- Patterns: []Pattern{
- {Regex: flagRegex, Direction: DirectionToServer, Flags: RegexFlags{Utf8Mode: true}},
- },
- })
- }()
+ _, _ = rulesManager.AddRule(context.Background(), Rule{
+ Name: "flag_out",
+ Color: "#e53935",
+ Notes: "Mark connections where the flags are stolen",
+ Patterns: []Pattern{
+ {Regex: flagRegex, Direction: DirectionToClient, Flags: RegexFlags{Utf8Mode: true}},
+ },
+ })
+ _, _ = rulesManager.AddRule(context.Background(), Rule{
+ Name: "flag_in",
+ Color: "#43A047",
+ Notes: "Mark connections where the flags are placed",
+ Patterns: []Pattern{
+ {Regex: flagRegex, Direction: DirectionToServer, Flags: RegexFlags{Utf8Mode: true}},
+ },
+ })
} else {
if err := rulesManager.generateDatabase(rules[len(rules)-1].ID); err != nil {
return nil, err
@@ -348,11 +346,13 @@ func (rm *rulesManagerImpl) generateDatabase(version RowID) error {
return err
}
- rm.databaseUpdated <- RulesDatabase{
- database: database,
- databaseSize: len(rm.patterns),
- version: version,
- }
+ go func() {
+ rm.databaseUpdated <- RulesDatabase{
+ database: database,
+ databaseSize: len(rm.patterns),
+ version: version,
+ }
+ }()
return nil
}
diff --git a/rules_manager_test.go b/rules_manager_test.go
index dded096..215e601 100644
--- a/rules_manager_test.go
+++ b/rules_manager_test.go
@@ -31,7 +31,8 @@ func TestAddAndGetAllRules(t *testing.T) {
rulesManager, err := LoadRulesManager(wrapper.Storage, "FLAG{test}")
require.NoError(t, err)
impl := rulesManager.(*rulesManagerImpl)
- checkVersion(t, rulesManager, impl.rulesByName["flag"].ID)
+ checkVersion(t, rulesManager, impl.rulesByName["flag_out"].ID)
+ checkVersion(t, rulesManager, impl.rulesByName["flag_in"].ID)
emptyRule := Rule{Name: "empty", Color: "#fff", Enabled: true}
emptyID, err := rulesManager.AddRule(wrapper.Context, emptyRule)
assert.NoError(t, err)
@@ -120,8 +121,8 @@ func TestAddAndGetAllRules(t *testing.T) {
assert.Equal(t, expected, impl.rulesByName[expected.Name])
}
- assert.Len(t, impl.rules, 5)
- assert.Len(t, impl.rulesByName, 5)
+ assert.Len(t, impl.rules, 6)
+ assert.Len(t, impl.rulesByName, 6)
assert.Len(t, impl.patterns, 5)
assert.Len(t, impl.patternsIds, 5)
@@ -135,8 +136,9 @@ func TestAddAndGetAllRules(t *testing.T) {
checkRule(rule2, []int{2, 3})
checkRule(rule3, []int{3, 4})
- assert.Len(t, rulesManager.GetRules(), 5)
- assert.ElementsMatch(t, []Rule{impl.rulesByName["flag"], emptyRule, rule1, rule2, rule3}, rulesManager.GetRules())
+ assert.Len(t, rulesManager.GetRules(), 6)
+ assert.ElementsMatch(t, []Rule{impl.rulesByName["flag_out"], impl.rulesByName["flag_in"], emptyRule,
+ rule1, rule2, rule3}, rulesManager.GetRules())
wrapper.Destroy(t)
}
@@ -210,7 +212,8 @@ func TestFillWithMatchedRules(t *testing.T) {
rulesManager, err := LoadRulesManager(wrapper.Storage, "FLAG{test}")
require.NoError(t, err)
impl := rulesManager.(*rulesManagerImpl)
- checkVersion(t, rulesManager, impl.rulesByName["flag"].ID)
+ checkVersion(t, rulesManager, impl.rulesByName["flag_out"].ID)
+ checkVersion(t, rulesManager, impl.rulesByName["flag_in"].ID)
emptyRule, err := rulesManager.AddRule(wrapper.Context, Rule{Name: "empty", Color: "#fff"})
require.NoError(t, err)