aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--connection_handler_test.go4
-rw-r--r--routes.go2
-rw-r--r--rules_manager.go87
-rw-r--r--rules_manager_test.go117
4 files changed, 185 insertions, 25 deletions
diff --git a/connection_handler_test.go b/connection_handler_test.go
index 9fe41af..7b00f6e 100644
--- a/connection_handler_test.go
+++ b/connection_handler_test.go
@@ -194,6 +194,10 @@ func (rm TestRulesManager) GetRules() []Rule {
return nil
}
+func (rm TestRulesManager) SetFlag(_ context.Context, _ string) error {
+ return nil
+}
+
func (rm TestRulesManager) FillWithMatchedRules(_ *Connection, _ map[uint][]PatternSlice, _ map[uint][]PatternSlice) {
}
diff --git a/routes.go b/routes.go
index 4599b8f..b628a9c 100644
--- a/routes.go
+++ b/routes.go
@@ -95,7 +95,7 @@ func badRequest(c *gin.Context, err error) {
}
func unprocessableEntity(c *gin.Context, err error) {
- c.JSON(http.StatusOK, UnorderedDocument{"result": "error", "error": err.Error()})
+ c.JSON(http.StatusUnprocessableEntity, UnorderedDocument{"result": "error", "error": err.Error()})
}
func notFound(c *gin.Context, obj interface{}) {
diff --git a/rules_manager.go b/rules_manager.go
index 89b6153..fdd3eaa 100644
--- a/rules_manager.go
+++ b/rules_manager.go
@@ -25,17 +25,17 @@ type RegexFlags struct {
}
type Pattern struct {
- Regex string `json:"regex" binding:"min=1" bson:"regex"`
+ Regex string `json:"regex" binding:"required,min=1" bson:"regex"`
Flags RegexFlags `json:"flags" bson:"flags,omitempty"`
MinOccurrences uint `json:"min_occurrences" bson:"min_occurrences,omitempty"`
MaxOccurrences uint `json:"max_occurrences" binding:"omitempty,gtefield=MinOccurrences" bson:"max_occurrences,omitempty"`
Direction uint8 `json:"direction" binding:"omitempty,max=2" bson:"direction,omitempty"`
- internalID int
+ internalID uint
}
type Filter struct {
ServicePort uint16 `json:"service_port" bson:"service_port,omitempty"`
- ClientAddress string `json:"client_address" binding:"omitempty,ip_addr" bson:"client_address,omitempty"`
+ ClientAddress string `json:"client_address" binding:"omitempty,ip" bson:"client_address,omitempty"`
ClientPort uint16 `json:"client_port" bson:"client_port,omitempty"`
MinDuration uint `json:"min_duration" bson:"min_duration,omitempty"`
MaxDuration uint `json:"max_duration" binding:"omitempty,gtefield=MinDuration" bson:"max_duration,omitempty"`
@@ -66,6 +66,7 @@ type RulesManager interface {
GetRule(id RowID) (Rule, bool)
UpdateRule(context context.Context, id RowID, rule Rule) (bool, error)
GetRules() []Rule
+ SetFlag(context context.Context, flagRegex string) error
FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice, serverMatches map[uint][]PatternSlice)
DatabaseUpdateChannel() chan RulesDatabase
}
@@ -75,7 +76,7 @@ type rulesManagerImpl struct {
rules map[RowID]Rule
rulesByName map[string]Rule
patterns []*hyperscan.Pattern
- patternsIds map[string]int
+ patternsIds map[string]uint
mutex sync.Mutex
databaseUpdated chan RulesDatabase
validate *validator.Validate
@@ -87,7 +88,7 @@ func NewRulesManager(storage Storage) RulesManager {
rules: make(map[RowID]Rule),
rulesByName: make(map[string]Rule),
patterns: make([]*hyperscan.Pattern, 0),
- patternsIds: make(map[string]int),
+ patternsIds: make(map[string]uint),
mutex: sync.Mutex{},
databaseUpdated: make(chan RulesDatabase, 1),
validate: validator.New(),
@@ -194,6 +195,78 @@ func (rm *rulesManagerImpl) SetFlag(context context.Context, flagRegex string) e
func (rm *rulesManagerImpl) FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice,
serverMatches map[uint][]PatternSlice) {
+ rm.mutex.Lock()
+
+ filterFunctions := []func (rule Rule)bool {
+ func(rule Rule) bool {
+ return rule.Filter.ClientAddress == "" || connection.SourceIP == rule.Filter.ClientAddress
+ },
+ func(rule Rule) bool {
+ return rule.Filter.ClientPort == 0 || connection.SourcePort == rule.Filter.ClientPort
+ },
+ func(rule Rule) bool {
+ return rule.Filter.ServicePort == 0 || connection.DestinationPort == rule.Filter.ServicePort
+ },
+ func(rule Rule) bool {
+ return rule.Filter.MinDuration == 0 || uint(connection.ClosedAt.Sub(connection.StartedAt).Milliseconds()) >=
+ rule.Filter.MinDuration
+ },
+ func(rule Rule) bool {
+ return rule.Filter.MaxDuration == 0 || uint(connection.ClosedAt.Sub(connection.StartedAt).Milliseconds()) <=
+ rule.Filter.MaxDuration
+ },
+ func(rule Rule) bool {
+ return rule.Filter.MinBytes == 0 || uint(connection.ClientBytes + connection.ServerBytes) >=
+ rule.Filter.MinBytes
+ },
+ func(rule Rule) bool {
+ return rule.Filter.MaxBytes == 0 || uint(connection.ClientBytes + connection.ServerBytes) <=
+ rule.Filter.MinBytes
+ },
+ }
+
+ connection.MatchedRules = make([]RowID, 0)
+ for _, rule := range rm.rules {
+ matching := true
+ for _, f := range filterFunctions {
+ if !f(rule) {
+ matching = false
+ break
+ }
+ }
+
+ for _, p := range rule.Patterns {
+ checkOccurrences := func(occurrences []PatternSlice) bool {
+ return (p.MinOccurrences == 0 || uint(len(occurrences)) >= p.MinOccurrences) &&
+ (p.MaxOccurrences == 0 || uint(len(occurrences)) <= p.MaxOccurrences)
+ }
+ clientOccurrences, clientPresent := clientMatches[p.internalID]
+ serverOccurrences, serverPresent := serverMatches[p.internalID]
+
+ if p.Direction == DirectionToServer {
+ if !clientPresent || !checkOccurrences(clientOccurrences) {
+ matching = false
+ break
+ }
+ } else if p.Direction == DirectionToClient {
+ if !serverPresent || !checkOccurrences(serverOccurrences) {
+ matching = false
+ break
+ }
+ } else {
+ if !(clientPresent || serverPresent) || !checkOccurrences(append(clientOccurrences, serverOccurrences...)) {
+ matching = false
+ break
+ }
+ }
+ }
+
+ if matching {
+ connection.MatchedRules = append(connection.MatchedRules, rule.ID)
+ }
+ }
+
+ rm.mutex.Unlock()
}
func (rm *rulesManagerImpl) DatabaseUpdateChannel() chan RulesDatabase {
@@ -226,7 +299,7 @@ func (rm *rulesManagerImpl) validateAndAddRuleLocal(rule *Rule) error {
}
id := len(rm.patternsIds) + len(newPatterns)
- rule.Patterns[i].internalID = id
+ rule.Patterns[i].internalID = uint(id)
compiledPattern.Id = id
newPatterns = append(newPatterns, compiledPattern)
duplicatePatterns[regex] = true
@@ -235,7 +308,7 @@ func (rm *rulesManagerImpl) validateAndAddRuleLocal(rule *Rule) error {
startId := len(rm.patterns)
for id, pattern := range newPatterns {
rm.patterns = append(rm.patterns, pattern)
- rm.patternsIds[pattern.String()] = startId + id
+ rm.patternsIds[pattern.String()] = uint(startId + id)
}
rm.rules[rule.ID] = *rule
diff --git a/rules_manager_test.go b/rules_manager_test.go
index 68e3ae4..59ce11c 100644
--- a/rules_manager_test.go
+++ b/rules_manager_test.go
@@ -13,25 +13,14 @@ func TestAddAndGetAllRules(t *testing.T) {
rulesManager := NewRulesManager(wrapper.Storage).(*rulesManagerImpl)
- checkVersion := func(id RowID) {
- timeout := time.Tick(1 * time.Second)
-
- select {
- case database := <-rulesManager.DatabaseUpdateChannel():
- assert.Equal(t, id, database.version)
- case <-timeout:
- t.Fatal("timeout")
- }
- }
-
err := rulesManager.SetFlag(wrapper.Context, "FLAG{test}")
assert.NoError(t, err)
- checkVersion(rulesManager.rulesByName["flag"].ID)
+ checkVersion(t, rulesManager, rulesManager.rulesByName["flag"].ID)
emptyRule := Rule{Name: "empty", Color: "#fff", Enabled: true}
emptyID, err := rulesManager.AddRule(wrapper.Context, emptyRule)
assert.NoError(t, err)
assert.NotNil(t, emptyID)
- checkVersion(emptyID)
+ checkVersion(t, rulesManager, emptyID)
duplicateRule, err := rulesManager.AddRule(wrapper.Context, Rule{Name: "empty", Color: "#eee"})
assert.Error(t, err)
@@ -72,7 +61,7 @@ func TestAddAndGetAllRules(t *testing.T) {
rule1ID, err := rulesManager.AddRule(wrapper.Context, rule1)
assert.NoError(t, err)
assert.NotNil(t, rule1ID)
- checkVersion(rule1ID)
+ checkVersion(t, rulesManager, rule1ID)
rule2 := Rule{
Name: "rule2",
@@ -86,7 +75,7 @@ func TestAddAndGetAllRules(t *testing.T) {
rule2ID, err := rulesManager.AddRule(wrapper.Context, rule2)
assert.NoError(t, err)
assert.NotNil(t, rule2ID)
- checkVersion(rule2ID)
+ checkVersion(t, rulesManager, rule2ID)
rule3 := Rule{
Name: "rule3",
@@ -100,7 +89,7 @@ func TestAddAndGetAllRules(t *testing.T) {
rule3ID, err := rulesManager.AddRule(wrapper.Context, rule3)
assert.NoError(t, err)
assert.NotNil(t, rule3ID)
- checkVersion(rule3ID)
+ checkVersion(t, rulesManager, rule3ID)
checkRule := func(expected Rule, patternIDs []int) {
var rule Rule
@@ -109,7 +98,7 @@ func TestAddAndGetAllRules(t *testing.T) {
require.NoError(t, err)
for i, id := range patternIDs {
- rule.Patterns[i].internalID = id
+ rule.Patterns[i].internalID = uint(id)
}
assert.Equal(t, expected, rule)
assert.Equal(t, expected, rulesManager.rules[expected.ID])
@@ -199,3 +188,97 @@ func TestLoadAndUpdateRules(t *testing.T) {
wrapper.Destroy(t)
}
+
+func TestFillWithMatchedRules(t *testing.T) {
+ wrapper := NewTestStorageWrapper(t)
+ wrapper.AddCollection(Rules)
+
+ rulesManager := NewRulesManager(wrapper.Storage).(*rulesManagerImpl)
+ err := rulesManager.SetFlag(wrapper.Context, "flag")
+ require.NoError(t, err)
+ checkVersion(t, rulesManager, rulesManager.rulesByName["flag"].ID)
+
+ emptyRule, err := rulesManager.AddRule(wrapper.Context, Rule{Name: "empty", Color: "#fff"})
+ require.NoError(t, err)
+ checkVersion(t, rulesManager, emptyRule)
+
+ conn := &Connection{}
+ rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{}, map[uint][]PatternSlice{})
+ assert.ElementsMatch(t, []RowID{emptyRule}, conn.MatchedRules)
+
+ filterRule, err := rulesManager.AddRule(wrapper.Context, Rule{
+ Name: "filter",
+ Color: "#fff",
+ Filter: Filter{
+ ServicePort: 80,
+ ClientAddress: "10.10.10.10",
+ ClientPort: 60000,
+ MinDuration: 2000,
+ MaxDuration: 4000,
+ MinBytes: 64,
+ MaxBytes: 64,
+ },
+ })
+ require.NoError(t, err)
+ checkVersion(t, rulesManager, filterRule)
+ conn = &Connection{
+ SourceIP: "10.10.10.10",
+ SourcePort: 60000,
+ DestinationPort: 80,
+ ClientBytes: 32,
+ ServerBytes: 32,
+ StartedAt: time.Now(),
+ ClosedAt: time.Now().Add(3 * time.Second),
+ }
+ rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{}, map[uint][]PatternSlice{})
+ assert.ElementsMatch(t, []RowID{emptyRule, filterRule}, conn.MatchedRules)
+
+ patternRule, err := rulesManager.AddRule(wrapper.Context, Rule{
+ Name: "pattern",
+ Color: "#fff",
+ Patterns: []Pattern{
+ {Regex: "pattern1", Direction: DirectionToClient, MinOccurrences: 1},
+ {Regex: "pattern2", Direction: DirectionToServer, MaxOccurrences: 2},
+ {Regex: "pattern3", Direction: DirectionBoth, MinOccurrences: 2, MaxOccurrences: 2},
+ },
+ })
+ require.NoError(t, err)
+ checkVersion(t, rulesManager, patternRule)
+ conn = &Connection{}
+ rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0},{0, 0}}, 3: {{0, 0}}},
+ map[uint][]PatternSlice{1: {{0, 0}}, 3: {{0, 0}}})
+ assert.ElementsMatch(t, []RowID{emptyRule, patternRule}, conn.MatchedRules)
+
+ rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0},{0, 0}}},
+ map[uint][]PatternSlice{1: {{0, 0}}, 3: {{0, 0},{0, 0}}})
+ assert.ElementsMatch(t, []RowID{emptyRule, patternRule}, conn.MatchedRules)
+
+ rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0},{0, 0}}, 3: {{0, 0},{0, 0}}},
+ map[uint][]PatternSlice{1: {{0, 0}}})
+ assert.ElementsMatch(t, []RowID{emptyRule, patternRule}, conn.MatchedRules)
+
+ rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0},{0, 0}}, 3: {{0, 0}}},
+ map[uint][]PatternSlice{3: {{0, 0}}})
+ assert.ElementsMatch(t, []RowID{emptyRule}, conn.MatchedRules)
+
+ rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0},{0, 0},{0, 0}}, 3: {{0, 0}}},
+ map[uint][]PatternSlice{1: {{0, 0}}, 3: {{0, 0}}})
+ assert.ElementsMatch(t, []RowID{emptyRule}, conn.MatchedRules)
+
+ rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0},{0, 0}}, 3: {{0, 0}}},
+ map[uint][]PatternSlice{1: {{0, 0}}, 3: {{0, 0},{0, 0}}})
+ assert.ElementsMatch(t, []RowID{emptyRule}, conn.MatchedRules)
+
+ wrapper.Destroy(t)
+}
+
+func checkVersion(t *testing.T, rulesManager RulesManager, id RowID) {
+ timeout := time.Tick(1 * time.Second)
+
+ select {
+ case database := <-rulesManager.DatabaseUpdateChannel():
+ assert.Equal(t, id, database.version)
+ case <-timeout:
+ t.Fatal("timeout")
+ }
+}