aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEmiliano Ciavatta2020-04-13 10:29:00 +0000
committerEmiliano Ciavatta2020-04-13 10:29:00 +0000
commit7113463dead05631339fdab94de9440201c42489 (patch)
tree0da5b2ca4b7fc100bb843bb277a816aca0f42421
parent25bd17a2147d7169695772c2a887cdd54caff770 (diff)
Update rules
-rw-r--r--connection_handler_test.go46
-rw-r--r--routes.go29
-rw-r--r--rules_manager.go175
3 files changed, 161 insertions, 89 deletions
diff --git a/connection_handler_test.go b/connection_handler_test.go
index aefa058..20eb041 100644
--- a/connection_handler_test.go
+++ b/connection_handler_test.go
@@ -96,7 +96,7 @@ func TestConnectionFactory(t *testing.T) {
testInteraction := func(netFlow gopacket.Flow, transportFlow gopacket.Flow, otherSeenChan chan time.Time,
completed chan bool) {
- time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
+ time.Sleep(time.Duration(rand.Intn(1000)) * time.Millisecond)
stream := factory.New(netFlow, transportFlow)
seen := time.Now()
stream.Reassembled([]tcpassembly.Reassembly{{[]byte{}, 0, true, true, seen}})
@@ -106,25 +106,25 @@ func TestConnectionFactory(t *testing.T) {
if netFlow == serverClientNetFlow {
otherSeenChan <- seen
return
+ }
+
+ otherSeen, ok := <-otherSeenChan
+ require.True(t, ok)
+
+ if seen.Before(otherSeen) {
+ startedAt = seen
+ closedAt = otherSeen
} else {
- otherSeen, ok := <-otherSeenChan
- require.True(t, ok)
-
- if seen.Before(otherSeen) {
- startedAt = seen
- closedAt = otherSeen
- } else {
- startedAt = otherSeen
- closedAt = seen
- }
+ startedAt = otherSeen
+ closedAt = seen
}
close(otherSeenChan)
var result Connection
connectionFlow := StreamFlow{netFlow.Src(), netFlow.Dst(), transportFlow.Src(), transportFlow.Dst()}
connectionID := wrapper.Storage.NewCustomRowID(connectionFlow.Hash(), startedAt)
- err = wrapper.Storage.Find(Connections).Context(wrapper.Context).
- Filter(OrderedDocument{{"_id", connectionID}}).First(&result)
+ op := wrapper.Storage.Find(Connections).Context(wrapper.Context)
+ err := op.Filter(OrderedDocument{{"_id", connectionID}}).First(&result)
require.NoError(t, err)
assert.NotNil(t, result)
@@ -140,7 +140,7 @@ func TestConnectionFactory(t *testing.T) {
}
completed := make(chan bool)
- n := 3000
+ n := 1000
for port := 40000; port < 40000+n; port++ {
clientPort := layers.NewTCPPortEndpoint(layers.TCPPort(port))
@@ -154,7 +154,7 @@ func TestConnectionFactory(t *testing.T) {
go testInteraction(serverClientNetFlow, serverClientTransportFlow, otherSeenChan, completed)
}
- timeout := time.Tick(1 * time.Second)
+ timeout := time.Tick(10 * time.Second)
for i := 0; i < n; i++ {
select {
case <- completed:
@@ -178,8 +178,20 @@ func (rm TestRuleManager) LoadRules() error {
return nil
}
-func (rm TestRuleManager) AddRule(_ context.Context, _ Rule) (string, error) {
- return "", nil
+func (rm TestRuleManager) AddRule(_ context.Context, _ Rule) (RowID, error) {
+ return RowID{}, nil
+}
+
+func (rm TestRuleManager) GetRule(_ RowID) (Rule, bool) {
+ return Rule{}, false
+}
+
+func (rm TestRuleManager) UpdateRule(_ context.Context, _ Rule) bool {
+ return false
+}
+
+func (rm TestRuleManager) GetRules() []Rule {
+ return nil
}
func (rm TestRuleManager) FillWithMatchedRules(_ *Connection, _ map[uint][]PatternSlice, _ map[uint][]PatternSlice) {
diff --git a/routes.go b/routes.go
index 3759382..37088e2 100644
--- a/routes.go
+++ b/routes.go
@@ -3,8 +3,8 @@ package main
import (
"fmt"
"github.com/gin-gonic/gin"
- "github.com/go-playground/validator/v10"
log "github.com/sirupsen/logrus"
+ "github.com/go-playground/validator/v10"
"net/http"
)
@@ -13,17 +13,38 @@ func ApplicationRoutes(engine *gin.Engine) {
api := engine.Group("/api")
{
- api.POST("/rules", func(c *gin.Context) {
+ api.POST("/rule", func(c *gin.Context) {
var rule Rule
+ //data, _ := c.GetRawData()
+ //
+ //var json = jsoniter.ConfigCompatibleWithStandardLibrary
+ //err := json.Unmarshal(data, &filter)
+ //
+ //if err != nil {
+ // log.WithError(err).Error("failed to unmarshal")
+ // c.String(500, "failed to unmarshal")
+ //}
+ //
+ //err = validator.New().Struct(filter)
+ //log.WithError(err).WithField("filter", filter).Error("aaaa")
+ //c.String(200, "ok")
+
if err := c.ShouldBindJSON(&rule); err != nil {
- for _, fieldErr := range err.(validator.ValidationErrors) {
+ validationErrors, ok := err.(validator.ValidationErrors)
+ if !ok {
+ log.WithError(err).WithField("rule", rule).Error("oops")
+ c.JSON(http.StatusBadRequest, gin.H{})
+ return
+ }
+
+ for _, fieldErr := range validationErrors {
log.Println(fieldErr)
c.JSON(http.StatusBadRequest, gin.H{
"error": fmt.Sprintf("field '%v' does not respect the %v(%v) rule",
fieldErr.Field(), fieldErr.Tag(), fieldErr.Param()),
})
- log.WithError(err).WithField("rule", rule).Panic("oops")
+ log.WithError(err).WithField("rule", rule).Error("oops")
return // exit on first error
}
}
diff --git a/rules_manager.go b/rules_manager.go
index 0750d53..388aeee 100644
--- a/rules_manager.go
+++ b/rules_manager.go
@@ -6,45 +6,44 @@ import (
"errors"
"fmt"
"github.com/flier/gohs/hyperscan"
+ "github.com/go-playground/validator/v10"
log "github.com/sirupsen/logrus"
"sync"
"time"
)
type RegexFlags struct {
- Caseless bool `json:"caseless"` // Set case-insensitive matching.
- DotAll bool `json:"dot_all"` // Matching a `.` will not exclude newlines.
- MultiLine bool `json:"multi_line"` // Set multi-line anchoring.
- SingleMatch bool `json:"single_match"` // Set single-match only mode.
- Utf8Mode bool `json:"utf_8_mode"` // Enable UTF-8 mode for this expression.
- UnicodeProperty bool `json:"unicode_property"` // Enable Unicode property support for this expression
+ Caseless bool `json:"caseless" bson:"caseless,omitempty"` // Set case-insensitive matching.
+ DotAll bool `json:"dot_all" bson:"dot_all,omitempty"` // Matching a `.` will not exclude newlines.
+ MultiLine bool `json:"multi_line" bson:"multi_line,omitempty"` // Set multi-line anchoring.
+ SingleMatch bool `json:"single_match" bson:"single_match,omitempty"` // Set single-match only mode.
+ Utf8Mode bool `json:"utf_8_mode" bson:"utf_8_mode,omitempty"` // Enable UTF-8 mode for this expression.
+ UnicodeProperty bool `json:"unicode_property" bson:"unicode_property,omitempty"` // Enable Unicode property support for this expression
}
type Pattern struct {
- Regex string `json:"regex"`
- Flags RegexFlags `json:"flags"`
- MinOccurrences int `json:"min_occurrences"`
- MaxOccurrences int `json:"max_occurrences"`
+ Regex string `json:"regex" binding:"min=1" bson:"regex"`
+ Flags RegexFlags `json:"flags" bson:"flags,omitempty"`
+ MinOccurrences uint `json:"min_occurrences" bson:"min_occurrences,omitempty"`
+ MaxOccurrences uint `json:"max_occurrences" binding:"omitempty,gtefield=MinOccurrences" bson:"max_occurrences,omitempty"`
internalID int
compiledPattern *hyperscan.Pattern
}
type Filter struct {
- ServicePort int
- ClientAddress string
- ClientPort int
- MinDuration int
- MaxDuration int
- MinPackets int
- MaxPackets int
- MinSize int
- MaxSize int
+ ServicePort uint16 `json:"service_port" bson:"service_port,omitempty"`
+ ClientAddress string `json:"client_address" binding:"ip_addr" bson:"client_address,omitempty"`
+ ClientPort uint16 `json:"client_port" bson:"client_port,omitempty"`
+ MinDuration uint `json:"min_duration" bson:"min_duration,omitempty"`
+ MaxDuration uint `json:"max_duration" binding:"omitempty,gtefield=MinDuration" bson:"max_duration,omitempty"`
+ MinBytes uint `json:"min_bytes" bson:"min_bytes,omitempty"`
+ MaxBytes uint `json:"max_bytes" binding:"omitempty,gtefield=MinBytes" bson:"max_bytes,omitempty"`
}
type Rule struct {
- ID RowID `json:"-" bson:"_id,omitempty"`
- Name string `json:"name" binding:"required,min=3" bson:"name"`
- Color string `json:"color" binding:"required,hexcolor" bson:"color"`
+ ID RowID `json:"id" bson:"_id,omitempty"`
+ Name string `json:"name" binding:"min=3" bson:"name"`
+ Color string `json:"color" binding:"hexcolor" bson:"color"`
Notes string `json:"notes" bson:"notes,omitempty"`
Enabled bool `json:"enabled" bson:"enabled"`
Patterns []Pattern `json:"patterns" binding:"required,min=1" bson:"patterns"`
@@ -53,38 +52,44 @@ type Rule struct {
}
type RulesDatabase struct {
- database hyperscan.StreamDatabase
+ database hyperscan.StreamDatabase
databaseSize int
- version RowID
+ version RowID
}
type RulesManager interface {
LoadRules() error
- AddRule(context context.Context, rule Rule) (string, error)
+ AddRule(context context.Context, rule Rule) (RowID, error)
+ GetRule(id RowID) (Rule, bool)
+ UpdateRule(context context.Context, rule Rule) bool
+ GetRules() []Rule
FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice, serverMatches map[uint][]PatternSlice)
DatabaseUpdateChannel() chan RulesDatabase
}
type rulesManagerImpl struct {
storage Storage
- rules map[string]Rule
+ rules map[RowID]Rule
rulesByName map[string]Rule
- ruleIndex int
patterns map[string]Pattern
- mPatterns sync.Mutex
+ mutex sync.Mutex
databaseUpdated chan RulesDatabase
+ validate *validator.Validate
}
func NewRulesManager(storage Storage) RulesManager {
return &rulesManagerImpl{
- storage: storage,
- rules: make(map[string]Rule),
- patterns: make(map[string]Pattern),
- mPatterns: sync.Mutex{},
+ storage: storage,
+ rules: make(map[RowID]Rule),
+ rulesByName: make(map[string]Rule),
+ patterns: make(map[string]Pattern),
+ mutex: sync.Mutex{},
+ databaseUpdated: make(chan RulesDatabase),
+ validate: validator.New(),
}
}
-func (rm rulesManagerImpl) LoadRules() error {
+func (rm *rulesManagerImpl) LoadRules() error {
var rules []Rule
if err := rm.storage.Find(Rules).Sort("_id", true).All(&rules); err != nil {
return err
@@ -96,48 +101,96 @@ func (rm rulesManagerImpl) LoadRules() error {
}
}
- rm.ruleIndex = len(rules)
return rm.generateDatabase(rules[len(rules)-1].ID)
}
-func (rm rulesManagerImpl) AddRule(context context.Context, rule Rule) (string, error) {
- rm.mPatterns.Lock()
+func (rm *rulesManagerImpl) AddRule(context context.Context, rule Rule) (RowID, error) {
+ rm.mutex.Lock()
- rule.ID = rm.storage.NewCustomRowID(uint64(rm.ruleIndex), time.Now())
+ rule.ID = rm.storage.NewCustomRowID(uint64(len(rm.rules)), time.Now())
rule.Enabled = true
if err := rm.validateAndAddRuleLocal(&rule); err != nil {
- rm.mPatterns.Unlock()
- return "", err
+ rm.mutex.Unlock()
+ return rule.ID, err
}
if err := rm.generateDatabase(rule.ID); err != nil {
- rm.mPatterns.Unlock()
+ rm.mutex.Unlock()
log.WithError(err).WithField("rule", rule).Panic("failed to generate database")
}
- rm.mPatterns.Unlock()
+ rm.mutex.Unlock()
if _, err := rm.storage.Insert(Rules).Context(context).One(rule); err != nil {
log.WithError(err).WithField("rule", rule).Panic("failed to insert rule on database")
}
- return rule.ID.Hex(), nil
+ return rule.ID, nil
}
-func (rm rulesManagerImpl) validateAndAddRuleLocal(rule *Rule) error {
+func (rm *rulesManagerImpl) GetRule(id RowID) (Rule, bool) {
+ rule, isPresent := rm.rules[id]
+ return rule, isPresent
+}
+
+func (rm *rulesManagerImpl) UpdateRule(context context.Context, rule Rule) bool {
+ updated, err := rm.storage.Update(Rules).Context(context).Filter(OrderedDocument{{"_id", rule.ID}}).
+ One(UnorderedDocument{"name": rule.Name, "color": rule.Color})
+ if err != nil {
+ log.WithError(err).WithField("rule", rule).Panic("failed to update rule on database")
+ }
+
+ if updated {
+ rm.mutex.Lock()
+ newRule := rm.rules[rule.ID]
+ newRule.Name = rule.Name
+ newRule.Color = rule.Color
+
+ delete(rm.rulesByName, newRule.Name)
+ rm.rulesByName[rule.Name] = newRule
+ rm.rules[rule.ID] = newRule
+ rm.mutex.Unlock()
+ }
+
+ return updated
+}
+
+func (rm *rulesManagerImpl) GetRules() []Rule {
+ rules := make([]Rule, 0, len(rm.rules))
+
+ for _, rule := range rm.rules {
+ rules = append(rules, rule)
+ }
+
+ return rules
+}
+
+func (rm *rulesManagerImpl) FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice,
+ serverMatches map[uint][]PatternSlice) {
+}
+
+func (rm *rulesManagerImpl) DatabaseUpdateChannel() chan RulesDatabase {
+ return rm.databaseUpdated
+}
+
+func (rm *rulesManagerImpl) validateAndAddRuleLocal(rule *Rule) error {
if _, alreadyPresent := rm.rulesByName[rule.Name]; alreadyPresent {
return errors.New("rule name must be unique")
}
newPatterns := make(map[string]Pattern)
for i, pattern := range rule.Patterns {
+ if err := rm.validate.Struct(pattern); err != nil {
+ return err
+ }
+
hash := pattern.Hash()
if existingPattern, isPresent := rm.patterns[hash]; isPresent {
rule.Patterns[i] = existingPattern
continue
}
- err := pattern.BuildPattern()
- if err != nil {
+
+ if err := pattern.BuildPattern(); err != nil {
return err
}
pattern.internalID = len(rm.patterns) + len(newPatterns)
@@ -148,18 +201,16 @@ func (rm rulesManagerImpl) validateAndAddRuleLocal(rule *Rule) error {
rm.patterns[key] = value
}
- rm.rules[rule.ID.Hex()] = *rule
+ rm.rules[rule.ID] = *rule
rm.rulesByName[rule.Name] = *rule
return nil
}
-func (rm rulesManagerImpl) generateDatabase(version RowID) error {
- patterns := make([]*hyperscan.Pattern, len(rm.patterns))
- var i int
+func (rm *rulesManagerImpl) generateDatabase(version RowID) error {
+ patterns := make([]*hyperscan.Pattern, 0, len(rm.patterns))
for _, pattern := range rm.patterns {
- patterns[i] = pattern.compiledPattern
- i++
+ patterns = append(patterns, pattern.compiledPattern)
}
database, err := hyperscan.NewStreamDatabase(patterns...)
if err != nil {
@@ -167,31 +218,17 @@ func (rm rulesManagerImpl) generateDatabase(version RowID) error {
}
rm.databaseUpdated <- RulesDatabase{
- database: database,
+ database: database,
databaseSize: len(patterns),
- version: version,
+ version: version,
}
return nil
}
-func (rm rulesManagerImpl) FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice,
- serverMatches map[uint][]PatternSlice) {
-}
-
-func (rm rulesManagerImpl) DatabaseUpdateChannel() chan RulesDatabase {
- return rm.databaseUpdated
-}
-
-func (p Pattern) BuildPattern() error {
+func (p *Pattern) BuildPattern() error {
if p.compiledPattern != nil {
return nil
}
- if p.MinOccurrences <= 0 {
- return errors.New("min_occurrences can't be lower than zero")
- }
- if p.MaxOccurrences != -1 && p.MinOccurrences < p.MinOccurrences {
- return errors.New("max_occurrences can't be lower than min_occurrences")
- }
hp, err := hyperscan.ParsePattern(fmt.Sprintf("/%s/", p.Regex))
if err != nil {
@@ -221,6 +258,8 @@ func (p Pattern) BuildPattern() error {
return errors.New("can't validate the pattern")
}
+ p.compiledPattern = hp
+
return nil
}