aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEmiliano Ciavatta2020-04-21 10:57:10 +0000
committerEmiliano Ciavatta2020-04-21 10:57:10 +0000
commit516712f670803979c65fd3db73dd2de9d7175139 (patch)
tree44e701aee56031ea3a731352bd639ae1b060f83c
parent324623884309e95f541f285faee48150988ec466 (diff)
Add application_context
-rw-r--r--application_context.go92
-rw-r--r--application_router.go (renamed from routes.go)63
-rw-r--r--caronte.go20
-rw-r--r--rules_manager.go58
-rw-r--r--rules_manager_test.go32
-rw-r--r--storage.go2
6 files changed, 213 insertions, 54 deletions
diff --git a/application_context.go b/application_context.go
new file mode 100644
index 0000000..77fae4e
--- /dev/null
+++ b/application_context.go
@@ -0,0 +1,92 @@
+package main
+
+import (
+ "github.com/gin-gonic/gin"
+ log "github.com/sirupsen/logrus"
+ "net"
+)
+
+type Config struct {
+ ServerIP string `json:"server_ip" binding:"required,ip" bson:"server_ip"`
+ FlagRegex string `json:"flag_regex" binding:"required,min=8" bson:"flag_regex"`
+ AuthRequired bool `json:"auth_required" bson:"auth_required"`
+}
+
+type ApplicationContext struct {
+ Storage Storage
+ Config Config
+ Accounts gin.Accounts
+ RulesManager RulesManager
+ PcapImporter *PcapImporter
+ IsConfigured bool
+}
+
+func CreateApplicationContext(storage Storage) (*ApplicationContext, error) {
+ var configWrapper struct {
+ config Config
+ }
+ if err := storage.Find(Settings).Filter(OrderedDocument{{"_id", "config"}}).
+ First(&configWrapper); err != nil {
+ return nil, err
+ }
+ var accountsWrapper struct {
+ accounts gin.Accounts
+ }
+
+ if err := storage.Find(Settings).Filter(OrderedDocument{{"_id", "accounts"}}).
+ First(&accountsWrapper); err != nil {
+ return nil, err
+ }
+ if accountsWrapper.accounts == nil {
+ accountsWrapper.accounts = make(gin.Accounts)
+ }
+
+ applicationContext := &ApplicationContext{
+ Storage: storage,
+ Config: configWrapper.config,
+ Accounts: accountsWrapper.accounts,
+ }
+
+ applicationContext.configure()
+ return applicationContext, nil
+}
+
+func (sm *ApplicationContext) SetConfig(config Config) {
+ sm.Config = config
+ sm.configure()
+ var upsertResults interface{}
+ if _, err := sm.Storage.Update(Settings).Upsert(&upsertResults).
+ Filter(OrderedDocument{{"_id", "config"}}).One(UnorderedDocument{"config": config}); err != nil {
+ log.WithError(err).WithField("config", config).Error("failed to update config")
+ }
+}
+
+func (sm *ApplicationContext) SetAccounts(accounts gin.Accounts) {
+ sm.Accounts = accounts
+ var upsertResults interface{}
+ if _, err := sm.Storage.Update(Settings).Upsert(&upsertResults).
+ Filter(OrderedDocument{{"_id", "accounts"}}).One(UnorderedDocument{"accounts": accounts}); err != nil {
+ log.WithError(err).Error("failed to update accounts")
+ }
+}
+
+func (sm *ApplicationContext) configure() {
+ if sm.IsConfigured {
+ return
+ }
+ if sm.Config.ServerIP == "" || sm.Config.FlagRegex == "" {
+ return
+ }
+ serverIP := net.ParseIP(sm.Config.ServerIP)
+ if serverIP == nil {
+ return
+ }
+
+ rulesManager, err := LoadRulesManager(sm.Storage, sm.Config.FlagRegex)
+ if err != nil {
+ log.WithError(err).Panic("failed to create a RulesManager")
+ }
+ sm.RulesManager = rulesManager
+ sm.PcapImporter = NewPcapImporter(sm.Storage, serverIP, sm.RulesManager)
+ sm.IsConfigured = true
+}
diff --git a/routes.go b/application_router.go
index b628a9c..3ec74d5 100644
--- a/routes.go
+++ b/application_router.go
@@ -3,15 +3,39 @@ package main
import (
"github.com/gin-gonic/gin"
"net/http"
+ log "github.com/sirupsen/logrus"
)
-func ApplicationRoutes(engine *gin.Engine, rulesManager RulesManager) {
+func CreateApplicationRouter(applicationContext *ApplicationContext) *gin.Engine {
+ router := gin.New()
+ router.Use(gin.Logger())
+ router.Use(gin.Recovery())
+
// engine.Static("/", "./frontend/build")
- api := engine.Group("/api")
+ router.POST("/setup", func(c *gin.Context) {
+ var settings struct {
+ Config Config `json:"config"`
+ Accounts gin.Accounts `json:"accounts"`
+ }
+
+ if err := c.ShouldBindJSON(&settings); err != nil {
+ badRequest(c, err)
+ return
+ }
+
+ applicationContext.SetConfig(settings.Config)
+ applicationContext.SetAccounts(settings.Accounts)
+
+ c.JSON(http.StatusAccepted, gin.H{})
+ })
+
+ api := router.Group("/api")
+ api.Use(SetupRequiredMiddleware(applicationContext))
+ api.Use(AuthRequiredMiddleware(applicationContext))
{
api.GET("/rules", func(c *gin.Context) {
- success(c, rulesManager.GetRules())
+ success(c, applicationContext.RulesManager.GetRules())
})
api.POST("/rules", func(c *gin.Context) {
@@ -22,7 +46,7 @@ func ApplicationRoutes(engine *gin.Engine, rulesManager RulesManager) {
return
}
- if id, err := rulesManager.AddRule(c, rule); err != nil {
+ if id, err := applicationContext.RulesManager.AddRule(c, rule); err != nil {
unprocessableEntity(c, err)
} else {
success(c, UnorderedDocument{"id": id})
@@ -36,7 +60,7 @@ func ApplicationRoutes(engine *gin.Engine, rulesManager RulesManager) {
badRequest(c, err)
return
}
- rule, found := rulesManager.GetRule(id)
+ rule, found := applicationContext.RulesManager.GetRule(id)
if !found {
notFound(c, UnorderedDocument{"id": id})
} else {
@@ -57,7 +81,7 @@ func ApplicationRoutes(engine *gin.Engine, rulesManager RulesManager) {
return
}
- updated, err := rulesManager.UpdateRule(c, id, rule)
+ updated, err := applicationContext.RulesManager.UpdateRule(c, id, rule)
if err != nil {
badRequest(c, err)
} else if !updated {
@@ -67,6 +91,33 @@ func ApplicationRoutes(engine *gin.Engine, rulesManager RulesManager) {
}
})
}
+
+ return router
+}
+
+func SetupRequiredMiddleware(applicationContext *ApplicationContext) gin.HandlerFunc {
+ return func(c *gin.Context) {
+ log.Error("aaaaaaaaaaaaaa")
+ if !applicationContext.IsConfigured {
+ c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{
+ "error": "setup required",
+ "url": c.Request.Host + "/setup",
+ })
+ } else {
+ c.Next()
+ }
+ }
+}
+
+func AuthRequiredMiddleware(applicationContext *ApplicationContext) gin.HandlerFunc {
+ return func(c *gin.Context) {
+ if !applicationContext.Config.AuthRequired {
+ c.Next()
+ return
+ }
+
+ gin.BasicAuth(applicationContext.Accounts)(c)
+ }
}
func success(c *gin.Context, obj interface{}) {
diff --git a/caronte.go b/caronte.go
index f65247a..d365143 100644
--- a/caronte.go
+++ b/caronte.go
@@ -1,9 +1,9 @@
package main
import (
+ "context"
"flag"
"fmt"
- "github.com/gin-gonic/gin"
log "github.com/sirupsen/logrus"
)
@@ -17,17 +17,19 @@ func main() {
flag.Parse()
+ logFields := log.Fields{"host": *mongoHost, "port": *mongoPort, "dbName": *dbName}
storage := NewMongoStorage(*mongoHost, *mongoPort, *dbName)
- err := storage.Connect(nil)
- if err != nil {
- log.WithError(err).Fatal("failed to connect to MongoDB")
+ if err := storage.Connect(context.Background()); err != nil {
+ log.WithError(err).WithFields(logFields).Fatal("failed to connect to MongoDB")
}
- rulesManager := NewRulesManager(storage)
- router := gin.Default()
- ApplicationRoutes(router, rulesManager)
- err = router.Run(fmt.Sprintf("%s:%v", *bindAddress, *bindPort))
+ applicationContext, err := CreateApplicationContext(storage)
if err != nil {
- log.WithError(err).Fatal("failed to create the server")
+ log.WithError(err).WithFields(logFields).Fatal("failed to create application context")
+ }
+
+ applicationRouter := CreateApplicationRouter(applicationContext)
+ if applicationRouter.Run(fmt.Sprintf("%s:%v", *bindAddress, *bindPort)) != nil {
+ log.WithError(err).WithFields(logFields).Fatal("failed to create the server")
}
}
diff --git a/rules_manager.go b/rules_manager.go
index fdd3eaa..015d8f7 100644
--- a/rules_manager.go
+++ b/rules_manager.go
@@ -25,12 +25,12 @@ type RegexFlags struct {
}
type Pattern struct {
- Regex string `json:"regex" binding:"required,min=1" bson:"regex"`
- Flags RegexFlags `json:"flags" bson:"flags,omitempty"`
- MinOccurrences uint `json:"min_occurrences" bson:"min_occurrences,omitempty"`
- MaxOccurrences uint `json:"max_occurrences" binding:"omitempty,gtefield=MinOccurrences" bson:"max_occurrences,omitempty"`
- Direction uint8 `json:"direction" binding:"omitempty,max=2" bson:"direction,omitempty"`
- internalID uint
+ Regex string `json:"regex" binding:"required,min=1" bson:"regex"`
+ Flags RegexFlags `json:"flags" bson:"flags,omitempty"`
+ MinOccurrences uint `json:"min_occurrences" bson:"min_occurrences,omitempty"`
+ MaxOccurrences uint `json:"max_occurrences" binding:"omitempty,gtefield=MinOccurrences" bson:"max_occurrences,omitempty"`
+ Direction uint8 `json:"direction" binding:"omitempty,max=2" bson:"direction,omitempty"`
+ internalID uint
}
type Filter struct {
@@ -61,12 +61,10 @@ type RulesDatabase struct {
}
type RulesManager interface {
- LoadRules() error
AddRule(context context.Context, rule Rule) (RowID, error)
GetRule(id RowID) (Rule, bool)
UpdateRule(context context.Context, id RowID, rule Rule) (bool, error)
GetRules() []Rule
- SetFlag(context context.Context, flagRegex string) error
FillWithMatchedRules(connection *Connection, clientMatches map[uint][]PatternSlice, serverMatches map[uint][]PatternSlice)
DatabaseUpdateChannel() chan RulesDatabase
}
@@ -82,8 +80,13 @@ type rulesManagerImpl struct {
validate *validator.Validate
}
-func NewRulesManager(storage Storage) RulesManager {
- return &rulesManagerImpl{
+func LoadRulesManager(storage Storage, flagRegex string) (RulesManager, error) {
+ var rules []Rule
+ if err := storage.Find(Rules).Sort("_id", true).All(&rules); err != nil {
+ return nil, err
+ }
+
+ rulesManager := rulesManagerImpl{
storage: storage,
rules: make(map[RowID]Rule),
rulesByName: make(map[string]Rule),
@@ -93,21 +96,32 @@ func NewRulesManager(storage Storage) RulesManager {
databaseUpdated: make(chan RulesDatabase, 1),
validate: validator.New(),
}
-}
-func (rm *rulesManagerImpl) LoadRules() error {
- var rules []Rule
- if err := rm.storage.Find(Rules).Sort("_id", true).All(&rules); err != nil {
- return err
+ for _, rule := range rules {
+ if err := rulesManager.validateAndAddRuleLocal(&rule); err != nil {
+ return nil, err
+ }
}
- for _, rule := range rules {
- if err := rm.validateAndAddRuleLocal(&rule); err != nil {
- log.WithError(err).WithField("rule", rule).Warn("failed to import rule")
+ // if there are no rules in database (e.g. first run), set flagRegex as first rule
+ if len(rulesManager.rules) == 0 {
+ if _, err := rulesManager.AddRule(context.Background(), Rule{
+ Name: "flag",
+ Color: "#ff0000",
+ Notes: "Mark connections where the flag is stolen",
+ Patterns: []Pattern{
+ {Regex: flagRegex, Direction: DirectionToClient},
+ },
+ }); err != nil {
+ return nil, err
+ }
+ } else {
+ if err := rulesManager.generateDatabase(rules[len(rules)-1].ID); err != nil {
+ return nil, err
}
}
- return rm.generateDatabase(rules[len(rules)-1].ID)
+ return &rulesManager, nil
}
func (rm *rulesManagerImpl) AddRule(context context.Context, rule Rule) (RowID, error) {
@@ -197,7 +211,7 @@ func (rm *rulesManagerImpl) FillWithMatchedRules(connection *Connection, clientM
serverMatches map[uint][]PatternSlice) {
rm.mutex.Lock()
- filterFunctions := []func (rule Rule)bool {
+ filterFunctions := []func(rule Rule) bool{
func(rule Rule) bool {
return rule.Filter.ClientAddress == "" || connection.SourceIP == rule.Filter.ClientAddress
},
@@ -216,11 +230,11 @@ func (rm *rulesManagerImpl) FillWithMatchedRules(connection *Connection, clientM
rule.Filter.MaxDuration
},
func(rule Rule) bool {
- return rule.Filter.MinBytes == 0 || uint(connection.ClientBytes + connection.ServerBytes) >=
+ return rule.Filter.MinBytes == 0 || uint(connection.ClientBytes+connection.ServerBytes) >=
rule.Filter.MinBytes
},
func(rule Rule) bool {
- return rule.Filter.MaxBytes == 0 || uint(connection.ClientBytes + connection.ServerBytes) <=
+ return rule.Filter.MaxBytes == 0 || uint(connection.ClientBytes+connection.ServerBytes) <=
rule.Filter.MinBytes
},
}
diff --git a/rules_manager_test.go b/rules_manager_test.go
index 59ce11c..f06362b 100644
--- a/rules_manager_test.go
+++ b/rules_manager_test.go
@@ -11,11 +11,10 @@ func TestAddAndGetAllRules(t *testing.T) {
wrapper := NewTestStorageWrapper(t)
wrapper.AddCollection(Rules)
- rulesManager := NewRulesManager(wrapper.Storage).(*rulesManagerImpl)
-
- err := rulesManager.SetFlag(wrapper.Context, "FLAG{test}")
- assert.NoError(t, err)
- checkVersion(t, rulesManager, rulesManager.rulesByName["flag"].ID)
+ rulesManager, err := LoadRulesManager(wrapper.Storage, "FLAG{test}")
+ require.NoError(t, err)
+ impl := rulesManager.(*rulesManagerImpl)
+ checkVersion(t, rulesManager, impl.rulesByName["flag"].ID)
emptyRule := Rule{Name: "empty", Color: "#fff", Enabled: true}
emptyID, err := rulesManager.AddRule(wrapper.Context, emptyRule)
assert.NoError(t, err)
@@ -101,14 +100,14 @@ func TestAddAndGetAllRules(t *testing.T) {
rule.Patterns[i].internalID = uint(id)
}
assert.Equal(t, expected, rule)
- assert.Equal(t, expected, rulesManager.rules[expected.ID])
- assert.Equal(t, expected, rulesManager.rulesByName[expected.Name])
+ assert.Equal(t, expected, impl.rules[expected.ID])
+ assert.Equal(t, expected, impl.rulesByName[expected.Name])
}
- assert.Len(t, rulesManager.rules, 5)
- assert.Len(t, rulesManager.rulesByName, 5)
- assert.Len(t, rulesManager.patterns, 5)
- assert.Len(t, rulesManager.patternsIds, 5)
+ assert.Len(t, impl.rules, 5)
+ assert.Len(t, impl.rulesByName, 5)
+ assert.Len(t, impl.patterns, 5)
+ assert.Len(t, impl.patternsIds, 5)
emptyRule.ID = emptyID
rule1.ID = rule1ID
@@ -121,7 +120,7 @@ func TestAddAndGetAllRules(t *testing.T) {
checkRule(rule3, []int{3, 4})
assert.Len(t, rulesManager.GetRules(), 5)
- assert.ElementsMatch(t, []Rule{rulesManager.rulesByName["flag"], emptyRule, rule1, rule2, rule3}, rulesManager.GetRules())
+ assert.ElementsMatch(t, []Rule{impl.rulesByName["flag"], emptyRule, rule1, rule2, rule3}, rulesManager.GetRules())
wrapper.Destroy(t)
}
@@ -150,8 +149,7 @@ func TestLoadAndUpdateRules(t *testing.T) {
require.NoError(t, err)
assert.ElementsMatch(t, expectedIds, ids)
- rulesManager := NewRulesManager(wrapper.Storage).(*rulesManagerImpl)
- err = rulesManager.LoadRules()
+ rulesManager, err := LoadRulesManager(wrapper.Storage, "FLAG{nope}")
require.NoError(t, err)
rule, isPresent := rulesManager.GetRule(NewRowID())
@@ -193,10 +191,10 @@ func TestFillWithMatchedRules(t *testing.T) {
wrapper := NewTestStorageWrapper(t)
wrapper.AddCollection(Rules)
- rulesManager := NewRulesManager(wrapper.Storage).(*rulesManagerImpl)
- err := rulesManager.SetFlag(wrapper.Context, "flag")
+ rulesManager, err := LoadRulesManager(wrapper.Storage, "FLAG{test}")
require.NoError(t, err)
- checkVersion(t, rulesManager, rulesManager.rulesByName["flag"].ID)
+ impl := rulesManager.(*rulesManagerImpl)
+ checkVersion(t, rulesManager, impl.rulesByName["flag"].ID)
emptyRule, err := rulesManager.AddRule(wrapper.Context, Rule{Name: "empty", Color: "#fff"})
require.NoError(t, err)
diff --git a/storage.go b/storage.go
index 5c77c6c..1a17723 100644
--- a/storage.go
+++ b/storage.go
@@ -16,6 +16,7 @@ const Connections = "connections"
const ConnectionStreams = "connection_streams"
const ImportingSessions = "importing_sessions"
const Rules = "rules"
+const Settings = "settings"
var ZeroRowID [12]byte
@@ -49,6 +50,7 @@ func NewMongoStorage(uri string, port int, database string) *MongoStorage {
ConnectionStreams: db.Collection(ConnectionStreams),
ImportingSessions: db.Collection(ImportingSessions),
Rules: db.Collection(Rules),
+ Settings: db.Collection(Settings),
}
return &MongoStorage{