diff options
author | Emiliano Ciavatta | 2020-04-21 10:57:10 +0000 |
---|---|---|
committer | Emiliano Ciavatta | 2020-04-21 10:57:10 +0000 |
commit | 516712f670803979c65fd3db73dd2de9d7175139 (patch) | |
tree | 44e701aee56031ea3a731352bd639ae1b060f83c | |
parent | 324623884309e95f541f285faee48150988ec466 (diff) |
Add application_context
-rw-r--r-- | application_context.go | 92 | ||||
-rw-r--r-- | application_router.go (renamed from routes.go) | 63 | ||||
-rw-r--r-- | caronte.go | 20 | ||||
-rw-r--r-- | rules_manager.go | 58 | ||||
-rw-r--r-- | rules_manager_test.go | 32 | ||||
-rw-r--r-- | storage.go | 2 |
6 files changed, 213 insertions, 54 deletions
diff --git a/application_context.go b/application_context.go new file mode 100644 index 0000000..77fae4e --- /dev/null +++ b/application_context.go @@ -0,0 +1,92 @@ +package main + +import ( + "github.com/gin-gonic/gin" + log "github.com/sirupsen/logrus" + "net" +) + +type Config struct { + ServerIP string `json:"server_ip" binding:"required,ip" bson:"server_ip"` + FlagRegex string `json:"flag_regex" binding:"required,min=8" bson:"flag_regex"` + AuthRequired bool `json:"auth_required" bson:"auth_required"` +} + +type ApplicationContext struct { + Storage Storage + Config Config + Accounts gin.Accounts + RulesManager RulesManager + PcapImporter *PcapImporter + IsConfigured bool +} + +func CreateApplicationContext(storage Storage) (*ApplicationContext, error) { + var configWrapper struct { + config Config + } + if err := storage.Find(Settings).Filter(OrderedDocument{{"_id", "config"}}). + First(&configWrapper); err != nil { + return nil, err + } + var accountsWrapper struct { + accounts gin.Accounts + } + + if err := storage.Find(Settings).Filter(OrderedDocument{{"_id", "accounts"}}). + First(&accountsWrapper); err != nil { + return nil, err + } + if accountsWrapper.accounts == nil { + accountsWrapper.accounts = make(gin.Accounts) + } + + applicationContext := &ApplicationContext{ + Storage: storage, + Config: configWrapper.config, + Accounts: accountsWrapper.accounts, + } + + applicationContext.configure() + return applicationContext, nil +} + +func (sm *ApplicationContext) SetConfig(config Config) { + sm.Config = config + sm.configure() + var upsertResults interface{} + if _, err := sm.Storage.Update(Settings).Upsert(&upsertResults). + Filter(OrderedDocument{{"_id", "config"}}).One(UnorderedDocument{"config": config}); err != nil { + log.WithError(err).WithField("config", config).Error("failed to update config") + } +} + +func (sm *ApplicationContext) SetAccounts(accounts gin.Accounts) { + sm.Accounts = accounts + var upsertResults interface{} + if _, err := sm.Storage.Update(Settings).Upsert(&upsertResults). + Filter(OrderedDocument{{"_id", "accounts"}}).One(UnorderedDocument{"accounts": accounts}); err != nil { + log.WithError(err).Error("failed to update accounts") + } +} + +func (sm *ApplicationContext) configure() { + if sm.IsConfigured { + return + } + if sm.Config.ServerIP == "" || sm.Config.FlagRegex == "" { + return + } + serverIP := net.ParseIP(sm.Config.ServerIP) + if serverIP == nil { + return + } + + rulesManager, err := LoadRulesManager(sm.Storage, sm.Config.FlagRegex) + if err != nil { + log.WithError(err).Panic("failed to create a RulesManager") + } + sm.RulesManager = rulesManager + sm.PcapImporter = NewPcapImporter(sm.Storage, serverIP, sm.RulesManager) + sm.IsConfigured = true +} diff --git a/routes.go b/application_router.go index b628a9c..3ec74d5 100644 --- a/routes.go +++ b/application_router.go @@ -3,15 +3,39 @@ package main import ( "github.com/gin-gonic/gin" "net/http" + log "github.com/sirupsen/logrus" ) -func ApplicationRoutes(engine *gin.Engine, rulesManager RulesManager) { +func CreateApplicationRouter(applicationContext *ApplicationContext) *gin.Engine { + router := gin.New() + router.Use(gin.Logger()) + router.Use(gin.Recovery()) + // engine.Static("/", "./frontend/build") - api := engine.Group("/api") + router.POST("/setup", func(c *gin.Context) { + var settings struct { + Config Config `json:"config"` + Accounts gin.Accounts `json:"accounts"` + } + + if err := c.ShouldBindJSON(&settings); err != nil { + badRequest(c, err) + return + } + + applicationContext.SetConfig(settings.Config) + applicationContext.SetAccounts(settings.Accounts) + + c.JSON(http.StatusAccepted, gin.H{}) + }) + + api := router.Group("/api") + api.Use(SetupRequiredMiddleware(applicationContext)) + api.Use(AuthRequiredMiddleware(applicationContext)) { api.GET("/rules", func(c *gin.Context) { - success(c, rulesManager.GetRules()) + success(c, applicationContext.RulesManager.GetRules()) }) api.POST("/rules", func(c *gin.Context) { @@ -22,7 +46,7 @@ func ApplicationRoutes(engine *gin.Engine, rulesManager RulesManager) { return } - if id, err := rulesManager.AddRule(c, rule); err != nil { + if id, err := applicationContext.RulesManager.AddRule(c, rule); err != nil { unprocessableEntity(c, err) } else { success(c, UnorderedDocument{"id": id}) @@ -36,7 +60,7 @@ func ApplicationRoutes(engine *gin.Engine, rulesManager RulesManager) { badRequest(c, err) return } - rule, found := rulesManager.GetRule(id) + rule, found := applicationContext.RulesManager.GetRule(id) if !found { notFound(c, UnorderedDocument{"id": id}) } else { @@ -57,7 +81,7 @@ func ApplicationRoutes(engine *gin.Engine, rulesManager RulesManager) { return } - updated, err := rulesManager.UpdateRule(c, id, rule) + updated, err := applicationContext.RulesManager.UpdateRule(c, id, rule) if err != nil { badRequest(c, err) } else if !updated { @@ -67,6 +91,33 @@ func ApplicationRoutes(engine *gin.Engine, rulesManager RulesManager) { } }) } + + return router +} + +func SetupRequiredMiddleware(applicationContext *ApplicationContext) gin.HandlerFunc { + return func(c *gin.Context) { + log.Error("aaaaaaaaaaaaaa") + if !applicationContext.IsConfigured { + c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{ + "error": "setup required", + "url": c.Request.Host + "/setup", + }) + } else { + c.Next() + } + } +} + +func AuthRequiredMiddleware(applicationContext *ApplicationContext) gin.HandlerFunc { + return func(c *gin.Context) { + if !applicationContext.Config.AuthRequired { + c.Next() + return + } + + gin.BasicAuth(applicationContext.Accounts)(c) + } } func success(c *gin.Context, obj interface{}) { @@ -1,9 +1,9 @@ package main import ( + "context" "flag" "fmt" - "github.com/gin-gonic/gin" log "github.com/sirupsen/logrus" ) @@ -17,17 +17,19 @@ func main() { flag.Parse() + logFields := log.Fields{"host": *mongoHost, "port": *mongoPort, "dbName": *dbName} storage := NewMongoStorage(*mongoHost, *mongoPort, *dbName) - err := storage.Connect(nil) - if err != nil { - log.WithError(err).Fatal("failed to connect to MongoDB") + if err := storage.Connect(context.Background()); err != nil { + log.WithError(err).WithFields(logFields).Fatal("failed to connect to MongoDB") } - rulesManager := NewRulesManager(storage) - router := gin.Default() - ApplicationRoutes(router, rulesManager) - err = router.Run(fmt.Sprintf("%s:%v", *bindAddress, *bindPort)) + applicationContext, err := CreateApplicationContext(storage) if err != nil { - log.WithError(err).Fatal("failed to create the server") + log.WithError(err).WithFields(logFields).Fatal("failed to create application context") + } + + applicationRouter := CreateApplicationRouter(applicationContext) + if applicationRouter.Run(fmt.Sprintf("%s:%v", *bindAddress, *bindPort)) != nil { + log.WithError(err).WithFields(logFields).Fatal("failed to create the server") } } diff --git a/rules_manager.go b/rules_manager.go index fdd3eaa..015d8f7 100644 --- a/rules_manager.go +++ b/rules_manager.go @@ -25,12 +25,12 @@ type RegexFlags struct { } type Pattern struct { - Regex string `json:"regex" binding:"required,min=1" bson:"regex"` - Flags RegexFlags `json:"flags" bson:"flags,omitempty"` - MinOccurrences uint `json:"min_occurrences" bson:"min_occurrences,omitempty"` - MaxOccurrences uint `json:"max_occurrences" binding:"omitempty,gtefield=MinOccurrences" bson:"max_occurrences,omitempty"` - Direction uint8 `json:"direction" binding:"omitempty,max=2" bson:"direction,omitempty"` - internalID uint + Regex string `json:"regex" binding:"required,min=1" bson:"regex"` + Flags RegexFlags `json:"flags" bson:"flags,omitempty"` + MinOccurrences uint `json:"min_occurrences" bson:"min_occurrences,omitempty"` + MaxOccurrences uint `json:"max_occurrences" binding:"omitempty,gtefield=MinOccurrences" bson:"max_occurrences,omitempty"` + Direction uint8 `json:"direction" binding:"omitempty,max=2" bson:"direction,omitempty"` + internalID uint } type Filter struct { @@ -61,12 +61,10 @@ type RulesDatabase struct { } type RulesManager interface { - LoadRules() error AddRule(context context.Context, rule Rule) (RowID, error) GetRule(id RowID) (Rule, bool) UpdateRule(context context.Context, id RowID, rule Rule) (bool, error) GetRules() []Rule - SetFlag(context context.Context, flagRegex string) error FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice, serverMatches map[uint][]PatternSlice) DatabaseUpdateChannel() chan RulesDatabase } @@ -82,8 +80,13 @@ type rulesManagerImpl struct { validate *validator.Validate } -func NewRulesManager(storage Storage) RulesManager { - return &rulesManagerImpl{ +func LoadRulesManager(storage Storage, flagRegex string) (RulesManager, error) { + var rules []Rule + if err := storage.Find(Rules).Sort("_id", true).All(&rules); err != nil { + return nil, err + } + + rulesManager := rulesManagerImpl{ storage: storage, rules: make(map[RowID]Rule), rulesByName: make(map[string]Rule), @@ -93,21 +96,32 @@ func NewRulesManager(storage Storage) RulesManager { databaseUpdated: make(chan RulesDatabase, 1), validate: validator.New(), } -} -func (rm *rulesManagerImpl) LoadRules() error { - var rules []Rule - if err := rm.storage.Find(Rules).Sort("_id", true).All(&rules); err != nil { - return err + for _, rule := range rules { + if err := rulesManager.validateAndAddRuleLocal(&rule); err != nil { + return nil, err + } } - for _, rule := range rules { - if err := rm.validateAndAddRuleLocal(&rule); err != nil { - log.WithError(err).WithField("rule", rule).Warn("failed to import rule") + // if there are no rules in database (e.g. first run), set flagRegex as first rule + if len(rulesManager.rules) == 0 { + if _, err := rulesManager.AddRule(context.Background(), Rule{ + Name: "flag", + Color: "#ff0000", + Notes: "Mark connections where the flag is stolen", + Patterns: []Pattern{ + {Regex: flagRegex, Direction: DirectionToClient}, + }, + }); err != nil { + return nil, err + } + } else { + if err := rulesManager.generateDatabase(rules[len(rules)-1].ID); err != nil { + return nil, err } } - return rm.generateDatabase(rules[len(rules)-1].ID) + return &rulesManager, nil } func (rm *rulesManagerImpl) AddRule(context context.Context, rule Rule) (RowID, error) { @@ -197,7 +211,7 @@ func (rm *rulesManagerImpl) FillWithMatchedRules(connection *Connection, clientM serverMatches map[uint][]PatternSlice) { rm.mutex.Lock() - filterFunctions := []func (rule Rule)bool { + filterFunctions := []func(rule Rule) bool{ func(rule Rule) bool { return rule.Filter.ClientAddress == "" || connection.SourceIP == rule.Filter.ClientAddress }, @@ -216,11 +230,11 @@ func (rm *rulesManagerImpl) FillWithMatchedRules(connection *Connection, clientM rule.Filter.MaxDuration }, func(rule Rule) bool { - return rule.Filter.MinBytes == 0 || uint(connection.ClientBytes + connection.ServerBytes) >= + return rule.Filter.MinBytes == 0 || uint(connection.ClientBytes+connection.ServerBytes) >= rule.Filter.MinBytes }, func(rule Rule) bool { - return rule.Filter.MaxBytes == 0 || uint(connection.ClientBytes + connection.ServerBytes) <= + return rule.Filter.MaxBytes == 0 || uint(connection.ClientBytes+connection.ServerBytes) <= rule.Filter.MinBytes }, } diff --git a/rules_manager_test.go b/rules_manager_test.go index 59ce11c..f06362b 100644 --- a/rules_manager_test.go +++ b/rules_manager_test.go @@ -11,11 +11,10 @@ func TestAddAndGetAllRules(t *testing.T) { wrapper := NewTestStorageWrapper(t) wrapper.AddCollection(Rules) - rulesManager := NewRulesManager(wrapper.Storage).(*rulesManagerImpl) - - err := rulesManager.SetFlag(wrapper.Context, "FLAG{test}") - assert.NoError(t, err) - checkVersion(t, rulesManager, rulesManager.rulesByName["flag"].ID) + rulesManager, err := LoadRulesManager(wrapper.Storage, "FLAG{test}") + require.NoError(t, err) + impl := rulesManager.(*rulesManagerImpl) + checkVersion(t, rulesManager, impl.rulesByName["flag"].ID) emptyRule := Rule{Name: "empty", Color: "#fff", Enabled: true} emptyID, err := rulesManager.AddRule(wrapper.Context, emptyRule) assert.NoError(t, err) @@ -101,14 +100,14 @@ func TestAddAndGetAllRules(t *testing.T) { rule.Patterns[i].internalID = uint(id) } assert.Equal(t, expected, rule) - assert.Equal(t, expected, rulesManager.rules[expected.ID]) - assert.Equal(t, expected, rulesManager.rulesByName[expected.Name]) + assert.Equal(t, expected, impl.rules[expected.ID]) + assert.Equal(t, expected, impl.rulesByName[expected.Name]) } - assert.Len(t, rulesManager.rules, 5) - assert.Len(t, rulesManager.rulesByName, 5) - assert.Len(t, rulesManager.patterns, 5) - assert.Len(t, rulesManager.patternsIds, 5) + assert.Len(t, impl.rules, 5) + assert.Len(t, impl.rulesByName, 5) + assert.Len(t, impl.patterns, 5) + assert.Len(t, impl.patternsIds, 5) emptyRule.ID = emptyID rule1.ID = rule1ID @@ -121,7 +120,7 @@ func TestAddAndGetAllRules(t *testing.T) { checkRule(rule3, []int{3, 4}) assert.Len(t, rulesManager.GetRules(), 5) - assert.ElementsMatch(t, []Rule{rulesManager.rulesByName["flag"], emptyRule, rule1, rule2, rule3}, rulesManager.GetRules()) + assert.ElementsMatch(t, []Rule{impl.rulesByName["flag"], emptyRule, rule1, rule2, rule3}, rulesManager.GetRules()) wrapper.Destroy(t) } @@ -150,8 +149,7 @@ func TestLoadAndUpdateRules(t *testing.T) { require.NoError(t, err) assert.ElementsMatch(t, expectedIds, ids) - rulesManager := NewRulesManager(wrapper.Storage).(*rulesManagerImpl) - err = rulesManager.LoadRules() + rulesManager, err := LoadRulesManager(wrapper.Storage, "FLAG{nope}") require.NoError(t, err) rule, isPresent := rulesManager.GetRule(NewRowID()) @@ -193,10 +191,10 @@ func TestFillWithMatchedRules(t *testing.T) { wrapper := NewTestStorageWrapper(t) wrapper.AddCollection(Rules) - rulesManager := NewRulesManager(wrapper.Storage).(*rulesManagerImpl) - err := rulesManager.SetFlag(wrapper.Context, "flag") + rulesManager, err := LoadRulesManager(wrapper.Storage, "FLAG{test}") require.NoError(t, err) - checkVersion(t, rulesManager, rulesManager.rulesByName["flag"].ID) + impl := rulesManager.(*rulesManagerImpl) + checkVersion(t, rulesManager, impl.rulesByName["flag"].ID) emptyRule, err := rulesManager.AddRule(wrapper.Context, Rule{Name: "empty", Color: "#fff"}) require.NoError(t, err) @@ -16,6 +16,7 @@ const Connections = "connections" const ConnectionStreams = "connection_streams" const ImportingSessions = "importing_sessions" const Rules = "rules" +const Settings = "settings" var ZeroRowID [12]byte @@ -49,6 +50,7 @@ func NewMongoStorage(uri string, port int, database string) *MongoStorage { ConnectionStreams: db.Collection(ConnectionStreams), ImportingSessions: db.Collection(ImportingSessions), Rules: db.Collection(Rules), + Settings: db.Collection(Settings), } return &MongoStorage{ |