From 516712f670803979c65fd3db73dd2de9d7175139 Mon Sep 17 00:00:00 2001 From: Emiliano Ciavatta Date: Tue, 21 Apr 2020 12:57:10 +0200 Subject: Add application_context --- application_context.go | 92 +++++++++++++++++++++++++++++ application_router.go | 154 +++++++++++++++++++++++++++++++++++++++++++++++++ caronte.go | 20 ++++--- routes.go | 103 --------------------------------- rules_manager.go | 58 ++++++++++++------- rules_manager_test.go | 32 +++++----- storage.go | 2 + 7 files changed, 310 insertions(+), 151 deletions(-) create mode 100644 application_context.go create mode 100644 application_router.go delete mode 100644 routes.go diff --git a/application_context.go b/application_context.go new file mode 100644 index 0000000..77fae4e --- /dev/null +++ b/application_context.go @@ -0,0 +1,92 @@ +package main + +import ( + "github.com/gin-gonic/gin" + log "github.com/sirupsen/logrus" + "net" +) + +type Config struct { + ServerIP string `json:"server_ip" binding:"required,ip" bson:"server_ip"` + FlagRegex string `json:"flag_regex" binding:"required,min=8" bson:"flag_regex"` + AuthRequired bool `json:"auth_required" bson:"auth_required"` +} + +type ApplicationContext struct { + Storage Storage + Config Config + Accounts gin.Accounts + RulesManager RulesManager + PcapImporter *PcapImporter + IsConfigured bool +} + +func CreateApplicationContext(storage Storage) (*ApplicationContext, error) { + var configWrapper struct { + config Config + } + if err := storage.Find(Settings).Filter(OrderedDocument{{"_id", "config"}}). + First(&configWrapper); err != nil { + return nil, err + } + var accountsWrapper struct { + accounts gin.Accounts + } + + if err := storage.Find(Settings).Filter(OrderedDocument{{"_id", "accounts"}}). + First(&accountsWrapper); err != nil { + return nil, err + } + if accountsWrapper.accounts == nil { + accountsWrapper.accounts = make(gin.Accounts) + } + + applicationContext := &ApplicationContext{ + Storage: storage, + Config: configWrapper.config, + Accounts: accountsWrapper.accounts, + } + + applicationContext.configure() + return applicationContext, nil +} + +func (sm *ApplicationContext) SetConfig(config Config) { + sm.Config = config + sm.configure() + var upsertResults interface{} + if _, err := sm.Storage.Update(Settings).Upsert(&upsertResults). + Filter(OrderedDocument{{"_id", "config"}}).One(UnorderedDocument{"config": config}); err != nil { + log.WithError(err).WithField("config", config).Error("failed to update config") + } +} + +func (sm *ApplicationContext) SetAccounts(accounts gin.Accounts) { + sm.Accounts = accounts + var upsertResults interface{} + if _, err := sm.Storage.Update(Settings).Upsert(&upsertResults). + Filter(OrderedDocument{{"_id", "accounts"}}).One(UnorderedDocument{"accounts": accounts}); err != nil { + log.WithError(err).Error("failed to update accounts") + } +} + +func (sm *ApplicationContext) configure() { + if sm.IsConfigured { + return + } + if sm.Config.ServerIP == "" || sm.Config.FlagRegex == "" { + return + } + serverIP := net.ParseIP(sm.Config.ServerIP) + if serverIP == nil { + return + } + + rulesManager, err := LoadRulesManager(sm.Storage, sm.Config.FlagRegex) + if err != nil { + log.WithError(err).Panic("failed to create a RulesManager") + } + sm.RulesManager = rulesManager + sm.PcapImporter = NewPcapImporter(sm.Storage, serverIP, sm.RulesManager) + sm.IsConfigured = true +} diff --git a/application_router.go b/application_router.go new file mode 100644 index 0000000..3ec74d5 --- /dev/null +++ b/application_router.go @@ -0,0 +1,154 @@ +package main + +import ( + "github.com/gin-gonic/gin" + "net/http" + log "github.com/sirupsen/logrus" +) + +func CreateApplicationRouter(applicationContext *ApplicationContext) *gin.Engine { + router := gin.New() + router.Use(gin.Logger()) + router.Use(gin.Recovery()) + + // engine.Static("/", "./frontend/build") + + router.POST("/setup", func(c *gin.Context) { + var settings struct { + Config Config `json:"config"` + Accounts gin.Accounts `json:"accounts"` + } + + if err := c.ShouldBindJSON(&settings); err != nil { + badRequest(c, err) + return + } + + applicationContext.SetConfig(settings.Config) + applicationContext.SetAccounts(settings.Accounts) + + c.JSON(http.StatusAccepted, gin.H{}) + }) + + api := router.Group("/api") + api.Use(SetupRequiredMiddleware(applicationContext)) + api.Use(AuthRequiredMiddleware(applicationContext)) + { + api.GET("/rules", func(c *gin.Context) { + success(c, applicationContext.RulesManager.GetRules()) + }) + + api.POST("/rules", func(c *gin.Context) { + var rule Rule + + if err := c.ShouldBindJSON(&rule); err != nil { + badRequest(c, err) + return + } + + if id, err := applicationContext.RulesManager.AddRule(c, rule); err != nil { + unprocessableEntity(c, err) + } else { + success(c, UnorderedDocument{"id": id}) + } + }) + + api.GET("/rules/:id", func(c *gin.Context) { + hex := c.Param("id") + id, err := RowIDFromHex(hex) + if err != nil { + badRequest(c, err) + return + } + rule, found := applicationContext.RulesManager.GetRule(id) + if !found { + notFound(c, UnorderedDocument{"id": id}) + } else { + success(c, rule) + } + }) + + api.PUT("/rules/:id", func(c *gin.Context) { + hex := c.Param("id") + id, err := RowIDFromHex(hex) + if err != nil { + badRequest(c, err) + return + } + var rule Rule + if err := c.ShouldBindJSON(&rule); err != nil { + badRequest(c, err) + return + } + + updated, err := applicationContext.RulesManager.UpdateRule(c, id, rule) + if err != nil { + badRequest(c, err) + } else if !updated { + notFound(c, UnorderedDocument{"id": id}) + } else { + success(c, rule) + } + }) + } + + return router +} + +func SetupRequiredMiddleware(applicationContext *ApplicationContext) gin.HandlerFunc { + return func(c *gin.Context) { + log.Error("aaaaaaaaaaaaaa") + if !applicationContext.IsConfigured { + c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{ + "error": "setup required", + "url": c.Request.Host + "/setup", + }) + } else { + c.Next() + } + } +} + +func AuthRequiredMiddleware(applicationContext *ApplicationContext) gin.HandlerFunc { + return func(c *gin.Context) { + if !applicationContext.Config.AuthRequired { + c.Next() + return + } + + gin.BasicAuth(applicationContext.Accounts)(c) + } +} + +func success(c *gin.Context, obj interface{}) { + c.JSON(http.StatusOK, obj) +} + +func badRequest(c *gin.Context, err error) { + c.JSON(http.StatusBadRequest, UnorderedDocument{"result": "error", "error": err.Error()}) + + //validationErrors, ok := err.(validator.ValidationErrors) + //if !ok { + // log.WithError(err).WithField("rule", rule).Error("oops") + // c.JSON(http.StatusBadRequest, gin.H{}) + // return + //} + // + //for _, fieldErr := range validationErrors { + // log.Println(fieldErr) + // c.JSON(http.StatusBadRequest, gin.H{ + // "error": fmt.Sprintf("field '%v' does not respect the %v(%v) rule", + // fieldErr.Field(), fieldErr.Tag(), fieldErr.Param()), + // }) + // log.WithError(err).WithField("rule", rule).Error("oops") + // return // exit on first error + //} +} + +func unprocessableEntity(c *gin.Context, err error) { + c.JSON(http.StatusUnprocessableEntity, UnorderedDocument{"result": "error", "error": err.Error()}) +} + +func notFound(c *gin.Context, obj interface{}) { + c.JSON(http.StatusNotFound, obj) +} diff --git a/caronte.go b/caronte.go index f65247a..d365143 100644 --- a/caronte.go +++ b/caronte.go @@ -1,9 +1,9 @@ package main import ( + "context" "flag" "fmt" - "github.com/gin-gonic/gin" log "github.com/sirupsen/logrus" ) @@ -17,17 +17,19 @@ func main() { flag.Parse() + logFields := log.Fields{"host": *mongoHost, "port": *mongoPort, "dbName": *dbName} storage := NewMongoStorage(*mongoHost, *mongoPort, *dbName) - err := storage.Connect(nil) - if err != nil { - log.WithError(err).Fatal("failed to connect to MongoDB") + if err := storage.Connect(context.Background()); err != nil { + log.WithError(err).WithFields(logFields).Fatal("failed to connect to MongoDB") } - rulesManager := NewRulesManager(storage) - router := gin.Default() - ApplicationRoutes(router, rulesManager) - err = router.Run(fmt.Sprintf("%s:%v", *bindAddress, *bindPort)) + applicationContext, err := CreateApplicationContext(storage) if err != nil { - log.WithError(err).Fatal("failed to create the server") + log.WithError(err).WithFields(logFields).Fatal("failed to create application context") + } + + applicationRouter := CreateApplicationRouter(applicationContext) + if applicationRouter.Run(fmt.Sprintf("%s:%v", *bindAddress, *bindPort)) != nil { + log.WithError(err).WithFields(logFields).Fatal("failed to create the server") } } diff --git a/routes.go b/routes.go deleted file mode 100644 index b628a9c..0000000 --- a/routes.go +++ /dev/null @@ -1,103 +0,0 @@ -package main - -import ( - "github.com/gin-gonic/gin" - "net/http" -) - -func ApplicationRoutes(engine *gin.Engine, rulesManager RulesManager) { - // engine.Static("/", "./frontend/build") - - api := engine.Group("/api") - { - api.GET("/rules", func(c *gin.Context) { - success(c, rulesManager.GetRules()) - }) - - api.POST("/rules", func(c *gin.Context) { - var rule Rule - - if err := c.ShouldBindJSON(&rule); err != nil { - badRequest(c, err) - return - } - - if id, err := rulesManager.AddRule(c, rule); err != nil { - unprocessableEntity(c, err) - } else { - success(c, UnorderedDocument{"id": id}) - } - }) - - api.GET("/rules/:id", func(c *gin.Context) { - hex := c.Param("id") - id, err := RowIDFromHex(hex) - if err != nil { - badRequest(c, err) - return - } - rule, found := rulesManager.GetRule(id) - if !found { - notFound(c, UnorderedDocument{"id": id}) - } else { - success(c, rule) - } - }) - - api.PUT("/rules/:id", func(c *gin.Context) { - hex := c.Param("id") - id, err := RowIDFromHex(hex) - if err != nil { - badRequest(c, err) - return - } - var rule Rule - if err := c.ShouldBindJSON(&rule); err != nil { - badRequest(c, err) - return - } - - updated, err := rulesManager.UpdateRule(c, id, rule) - if err != nil { - badRequest(c, err) - } else if !updated { - notFound(c, UnorderedDocument{"id": id}) - } else { - success(c, rule) - } - }) - } -} - -func success(c *gin.Context, obj interface{}) { - c.JSON(http.StatusOK, obj) -} - -func badRequest(c *gin.Context, err error) { - c.JSON(http.StatusBadRequest, UnorderedDocument{"result": "error", "error": err.Error()}) - - //validationErrors, ok := err.(validator.ValidationErrors) - //if !ok { - // log.WithError(err).WithField("rule", rule).Error("oops") - // c.JSON(http.StatusBadRequest, gin.H{}) - // return - //} - // - //for _, fieldErr := range validationErrors { - // log.Println(fieldErr) - // c.JSON(http.StatusBadRequest, gin.H{ - // "error": fmt.Sprintf("field '%v' does not respect the %v(%v) rule", - // fieldErr.Field(), fieldErr.Tag(), fieldErr.Param()), - // }) - // log.WithError(err).WithField("rule", rule).Error("oops") - // return // exit on first error - //} -} - -func unprocessableEntity(c *gin.Context, err error) { - c.JSON(http.StatusUnprocessableEntity, UnorderedDocument{"result": "error", "error": err.Error()}) -} - -func notFound(c *gin.Context, obj interface{}) { - c.JSON(http.StatusNotFound, obj) -} diff --git a/rules_manager.go b/rules_manager.go index fdd3eaa..015d8f7 100644 --- a/rules_manager.go +++ b/rules_manager.go @@ -25,12 +25,12 @@ type RegexFlags struct { } type Pattern struct { - Regex string `json:"regex" binding:"required,min=1" bson:"regex"` - Flags RegexFlags `json:"flags" bson:"flags,omitempty"` - MinOccurrences uint `json:"min_occurrences" bson:"min_occurrences,omitempty"` - MaxOccurrences uint `json:"max_occurrences" binding:"omitempty,gtefield=MinOccurrences" bson:"max_occurrences,omitempty"` - Direction uint8 `json:"direction" binding:"omitempty,max=2" bson:"direction,omitempty"` - internalID uint + Regex string `json:"regex" binding:"required,min=1" bson:"regex"` + Flags RegexFlags `json:"flags" bson:"flags,omitempty"` + MinOccurrences uint `json:"min_occurrences" bson:"min_occurrences,omitempty"` + MaxOccurrences uint `json:"max_occurrences" binding:"omitempty,gtefield=MinOccurrences" bson:"max_occurrences,omitempty"` + Direction uint8 `json:"direction" binding:"omitempty,max=2" bson:"direction,omitempty"` + internalID uint } type Filter struct { @@ -61,12 +61,10 @@ type RulesDatabase struct { } type RulesManager interface { - LoadRules() error AddRule(context context.Context, rule Rule) (RowID, error) GetRule(id RowID) (Rule, bool) UpdateRule(context context.Context, id RowID, rule Rule) (bool, error) GetRules() []Rule - SetFlag(context context.Context, flagRegex string) error FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice, serverMatches map[uint][]PatternSlice) DatabaseUpdateChannel() chan RulesDatabase } @@ -82,8 +80,13 @@ type rulesManagerImpl struct { validate *validator.Validate } -func NewRulesManager(storage Storage) RulesManager { - return &rulesManagerImpl{ +func LoadRulesManager(storage Storage, flagRegex string) (RulesManager, error) { + var rules []Rule + if err := storage.Find(Rules).Sort("_id", true).All(&rules); err != nil { + return nil, err + } + + rulesManager := rulesManagerImpl{ storage: storage, rules: make(map[RowID]Rule), rulesByName: make(map[string]Rule), @@ -93,21 +96,32 @@ func NewRulesManager(storage Storage) RulesManager { databaseUpdated: make(chan RulesDatabase, 1), validate: validator.New(), } -} -func (rm *rulesManagerImpl) LoadRules() error { - var rules []Rule - if err := rm.storage.Find(Rules).Sort("_id", true).All(&rules); err != nil { - return err + for _, rule := range rules { + if err := rulesManager.validateAndAddRuleLocal(&rule); err != nil { + return nil, err + } } - for _, rule := range rules { - if err := rm.validateAndAddRuleLocal(&rule); err != nil { - log.WithError(err).WithField("rule", rule).Warn("failed to import rule") + // if there are no rules in database (e.g. first run), set flagRegex as first rule + if len(rulesManager.rules) == 0 { + if _, err := rulesManager.AddRule(context.Background(), Rule{ + Name: "flag", + Color: "#ff0000", + Notes: "Mark connections where the flag is stolen", + Patterns: []Pattern{ + {Regex: flagRegex, Direction: DirectionToClient}, + }, + }); err != nil { + return nil, err + } + } else { + if err := rulesManager.generateDatabase(rules[len(rules)-1].ID); err != nil { + return nil, err } } - return rm.generateDatabase(rules[len(rules)-1].ID) + return &rulesManager, nil } func (rm *rulesManagerImpl) AddRule(context context.Context, rule Rule) (RowID, error) { @@ -197,7 +211,7 @@ func (rm *rulesManagerImpl) FillWithMatchedRules(connection *Connection, clientM serverMatches map[uint][]PatternSlice) { rm.mutex.Lock() - filterFunctions := []func (rule Rule)bool { + filterFunctions := []func(rule Rule) bool{ func(rule Rule) bool { return rule.Filter.ClientAddress == "" || connection.SourceIP == rule.Filter.ClientAddress }, @@ -216,11 +230,11 @@ func (rm *rulesManagerImpl) FillWithMatchedRules(connection *Connection, clientM rule.Filter.MaxDuration }, func(rule Rule) bool { - return rule.Filter.MinBytes == 0 || uint(connection.ClientBytes + connection.ServerBytes) >= + return rule.Filter.MinBytes == 0 || uint(connection.ClientBytes+connection.ServerBytes) >= rule.Filter.MinBytes }, func(rule Rule) bool { - return rule.Filter.MaxBytes == 0 || uint(connection.ClientBytes + connection.ServerBytes) <= + return rule.Filter.MaxBytes == 0 || uint(connection.ClientBytes+connection.ServerBytes) <= rule.Filter.MinBytes }, } diff --git a/rules_manager_test.go b/rules_manager_test.go index 59ce11c..f06362b 100644 --- a/rules_manager_test.go +++ b/rules_manager_test.go @@ -11,11 +11,10 @@ func TestAddAndGetAllRules(t *testing.T) { wrapper := NewTestStorageWrapper(t) wrapper.AddCollection(Rules) - rulesManager := NewRulesManager(wrapper.Storage).(*rulesManagerImpl) - - err := rulesManager.SetFlag(wrapper.Context, "FLAG{test}") - assert.NoError(t, err) - checkVersion(t, rulesManager, rulesManager.rulesByName["flag"].ID) + rulesManager, err := LoadRulesManager(wrapper.Storage, "FLAG{test}") + require.NoError(t, err) + impl := rulesManager.(*rulesManagerImpl) + checkVersion(t, rulesManager, impl.rulesByName["flag"].ID) emptyRule := Rule{Name: "empty", Color: "#fff", Enabled: true} emptyID, err := rulesManager.AddRule(wrapper.Context, emptyRule) assert.NoError(t, err) @@ -101,14 +100,14 @@ func TestAddAndGetAllRules(t *testing.T) { rule.Patterns[i].internalID = uint(id) } assert.Equal(t, expected, rule) - assert.Equal(t, expected, rulesManager.rules[expected.ID]) - assert.Equal(t, expected, rulesManager.rulesByName[expected.Name]) + assert.Equal(t, expected, impl.rules[expected.ID]) + assert.Equal(t, expected, impl.rulesByName[expected.Name]) } - assert.Len(t, rulesManager.rules, 5) - assert.Len(t, rulesManager.rulesByName, 5) - assert.Len(t, rulesManager.patterns, 5) - assert.Len(t, rulesManager.patternsIds, 5) + assert.Len(t, impl.rules, 5) + assert.Len(t, impl.rulesByName, 5) + assert.Len(t, impl.patterns, 5) + assert.Len(t, impl.patternsIds, 5) emptyRule.ID = emptyID rule1.ID = rule1ID @@ -121,7 +120,7 @@ func TestAddAndGetAllRules(t *testing.T) { checkRule(rule3, []int{3, 4}) assert.Len(t, rulesManager.GetRules(), 5) - assert.ElementsMatch(t, []Rule{rulesManager.rulesByName["flag"], emptyRule, rule1, rule2, rule3}, rulesManager.GetRules()) + assert.ElementsMatch(t, []Rule{impl.rulesByName["flag"], emptyRule, rule1, rule2, rule3}, rulesManager.GetRules()) wrapper.Destroy(t) } @@ -150,8 +149,7 @@ func TestLoadAndUpdateRules(t *testing.T) { require.NoError(t, err) assert.ElementsMatch(t, expectedIds, ids) - rulesManager := NewRulesManager(wrapper.Storage).(*rulesManagerImpl) - err = rulesManager.LoadRules() + rulesManager, err := LoadRulesManager(wrapper.Storage, "FLAG{nope}") require.NoError(t, err) rule, isPresent := rulesManager.GetRule(NewRowID()) @@ -193,10 +191,10 @@ func TestFillWithMatchedRules(t *testing.T) { wrapper := NewTestStorageWrapper(t) wrapper.AddCollection(Rules) - rulesManager := NewRulesManager(wrapper.Storage).(*rulesManagerImpl) - err := rulesManager.SetFlag(wrapper.Context, "flag") + rulesManager, err := LoadRulesManager(wrapper.Storage, "FLAG{test}") require.NoError(t, err) - checkVersion(t, rulesManager, rulesManager.rulesByName["flag"].ID) + impl := rulesManager.(*rulesManagerImpl) + checkVersion(t, rulesManager, impl.rulesByName["flag"].ID) emptyRule, err := rulesManager.AddRule(wrapper.Context, Rule{Name: "empty", Color: "#fff"}) require.NoError(t, err) diff --git a/storage.go b/storage.go index 5c77c6c..1a17723 100644 --- a/storage.go +++ b/storage.go @@ -16,6 +16,7 @@ const Connections = "connections" const ConnectionStreams = "connection_streams" const ImportingSessions = "importing_sessions" const Rules = "rules" +const Settings = "settings" var ZeroRowID [12]byte @@ -49,6 +50,7 @@ func NewMongoStorage(uri string, port int, database string) *MongoStorage { ConnectionStreams: db.Collection(ConnectionStreams), ImportingSessions: db.Collection(ImportingSessions), Rules: db.Collection(Rules), + Settings: db.Collection(Settings), } return &MongoStorage{ -- cgit v1.2.3-70-g09d2