From 468690c60ee2e57ed2ccb4375e9ada5d2fed9473 Mon Sep 17 00:00:00 2001 From: Emiliano Ciavatta Date: Tue, 7 Apr 2020 20:46:36 +0200 Subject: Before storage refactor --- rules_manager.go | 242 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 242 insertions(+) create mode 100644 rules_manager.go (limited to 'rules_manager.go') diff --git a/rules_manager.go b/rules_manager.go new file mode 100644 index 0000000..d6e9aaa --- /dev/null +++ b/rules_manager.go @@ -0,0 +1,242 @@ +package main + +import ( + "context" + "crypto/sha256" + "encoding/json" + "errors" + "fmt" + "github.com/flier/gohs/hyperscan" + "go.mongodb.org/mongo-driver/bson/primitive" + "log" + "sync" + "time" +) + + +type RegexFlags struct { + Caseless bool `json:"caseless"` // Set case-insensitive matching. + DotAll bool `json:"dot_all"` // Matching a `.` will not exclude newlines. + MultiLine bool `json:"multi_line"` // Set multi-line anchoring. + SingleMatch bool `json:"single_match"` // Set single-match only mode. + Utf8Mode bool `json:"utf_8_mode"` // Enable UTF-8 mode for this expression. + UnicodeProperty bool `json:"unicode_property"` // Enable Unicode property support for this expression +} + +type Pattern struct { + Regex string `json:"regex"` + Flags RegexFlags `json:"flags"` + MinOccurrences int `json:"min_occurrences"` + MaxOccurrences int `json:"max_occurrences"` + internalId int + compiledPattern *hyperscan.Pattern +} + +type Filter struct { + ServicePort int + ClientAddress string + ClientPort int + MinDuration int + MaxDuration int + MinPackets int + MaxPackets int + MinSize int + MaxSize int +} + +type Rule struct { + Id string `json:"-" bson:"_id,omitempty"` + Name string `json:"name" binding:"required,min=3" bson:"name"` + Color string `json:"color" binding:"required,hexcolor" bson:"color"` + Notes string `json:"notes" bson:"notes,omitempty"` + Enabled bool `json:"enabled" bson:"enabled"` + Patterns []Pattern `json:"patterns" binding:"required,min=1" bson:"patterns"` + Filter Filter `json:"filter" bson:"filter,omitempty"` + Version int64 `json:"version" bson:"version"` +} + +type RulesManager struct { + storage Storage + rules map[string]Rule + rulesByName map[string]Rule + ruleIndex int + patterns map[string]Pattern + mPatterns sync.Mutex + databaseUpdated chan interface{} +} + +func NewRulesManager(storage Storage) RulesManager { + return RulesManager{ + storage: storage, + rules: make(map[string]Rule), + patterns: make(map[string]Pattern), + mPatterns: sync.Mutex{}, + } +} + + + +func (rm RulesManager) LoadRules() error { + var rules []Rule + if err := rm.storage.Find(nil, Rules, NoFilters, &rules); err != nil { + return err + } + + var version int64 + for _, rule := range rules { + if err := rm.validateAndAddRuleLocal(&rule); err != nil { + log.Printf("failed to import rule %s: %s\n", rule.Name, err) + continue + } + if rule.Version > version { + version = rule.Version + } + } + + rm.ruleIndex = len(rules) + return rm.generateDatabase(0) +} + +func (rm RulesManager) AddRule(context context.Context, rule Rule) (string, error) { + rm.mPatterns.Lock() + + rule.Id = UniqueKey(time.Now(), uint32(rm.ruleIndex)) + rule.Enabled = true + + if err := rm.validateAndAddRuleLocal(&rule); err != nil { + rm.mPatterns.Unlock() + return "", err + } + rm.mPatterns.Unlock() + + if _, err := rm.storage.InsertOne(context, Rules, rule); err != nil { + return "", err + } + + return rule.Id, rm.generateDatabase(rule.Id) +} + + + +func (rm RulesManager) validateAndAddRuleLocal(rule *Rule) error { + if _, alreadyPresent := rm.rulesByName[rule.Name]; alreadyPresent { + return errors.New("rule name must be unique") + } + + newPatterns := make(map[string]Pattern) + for i, pattern := range rule.Patterns { + hash := pattern.Hash() + if existingPattern, isPresent := rm.patterns[hash]; isPresent { + rule.Patterns[i] = existingPattern + continue + } + err := pattern.BuildPattern() + if err != nil { + return err + } + pattern.internalId = len(rm.patterns) + len(newPatterns) + newPatterns[hash] = pattern + } + + for key, value := range newPatterns { + rm.patterns[key] = value + } + + rm.rules[rule.Id] = *rule + rm.rulesByName[rule.Name] = *rule + + return nil +} + +func (rm RulesManager) generateDatabase(version string) error { + patterns := make([]*hyperscan.Pattern, len(rm.patterns)) + for _, pattern := range rm.patterns { + patterns = append(patterns, pattern.compiledPattern) + } + database, err := hyperscan.NewStreamDatabase(patterns...) + if err != nil { + return err + } + + rm.databaseUpdated <- database + return nil +} + + +func (p Pattern) BuildPattern() error { + if p.compiledPattern != nil { + return nil + } + if p.MinOccurrences <= 0 { + return errors.New("min_occurrences can't be lower than zero") + } + if p.MaxOccurrences != -1 && p.MinOccurrences < p.MinOccurrences { + return errors.New("max_occurrences can't be lower than min_occurrences") + } + + hp, err := hyperscan.ParsePattern(fmt.Sprintf("/%s/", p.Regex)) + if err != nil { + return err + } + + if p.Flags.Caseless { + hp.Flags |= hyperscan.Caseless + } + if p.Flags.DotAll { + hp.Flags |= hyperscan.DotAll + } + if p.Flags.MultiLine { + hp.Flags |= hyperscan.MultiLine + } + if p.Flags.SingleMatch { + hp.Flags |= hyperscan.SingleMatch + } + if p.Flags.Utf8Mode { + hp.Flags |= hyperscan.Utf8Mode + } + if p.Flags.UnicodeProperty { + hp.Flags |= hyperscan.UnicodeProperty + } + + if !hp.IsValid() { + return errors.New("can't validate the pattern") + } + + return nil +} + +func (p Pattern) Hash() string { + hash := sha256.New() + hash.Write([]byte(fmt.Sprintf("%s|%v|%v|%v", p.Regex, p.Flags, p.MinOccurrences, p.MaxOccurrences))) + return fmt.Sprintf("%x", hash.Sum(nil)) +} + +func test() { + user := &Pattern{Regex: "Frank"} + b, err := json.Marshal(user) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(string(b)) + + p, _ := hyperscan.ParsePattern("/a/") + p1, _ := hyperscan.ParsePattern("/a/") + fmt.Println(p1.String(), p1.Flags) + //p1.Id = 1 + + fmt.Println(*p == *p1) + db, _ := hyperscan.NewBlockDatabase(p, p1) + s, _ := hyperscan.NewScratch(db) + db.Scan([]byte("Ciao"), s, onMatch, nil) + + + + +} + +func onMatch(id uint, from uint64, to uint64, flags uint, context interface{}) error { + fmt.Println(id) + + return nil +} \ No newline at end of file -- cgit v1.2.3-70-g09d2