From f8dd01e9cd59ff7c0920eab3dd65c02a15de059e Mon Sep 17 00:00:00 2001 From: Emiliano Ciavatta Date: Tue, 21 Apr 2020 19:26:23 +0200 Subject: Add application_router tests; change connection_stream key hash method --- application_router.go | 32 +++++++++++++++- application_router_test.go | 92 ++++++++++++++++++++++++++++++++++++++++++++++ connection_handler.go | 8 +++- pcap_importer.go | 4 ++ pcap_importer_test.go | 5 +++ stream_handler.go | 6 +-- utils.go | 16 ++++++++ 7 files changed, 156 insertions(+), 7 deletions(-) create mode 100644 application_router_test.go diff --git a/application_router.go b/application_router.go index d79fee8..bdc8007 100644 --- a/application_router.go +++ b/application_router.go @@ -1,6 +1,7 @@ package main import ( + "errors" "github.com/gin-gonic/gin" "net/http" ) @@ -13,9 +14,14 @@ func CreateApplicationRouter(applicationContext *ApplicationContext) *gin.Engine // engine.Static("/", "./frontend/build") router.POST("/setup", func(c *gin.Context) { + if applicationContext.IsConfigured { + c.AbortWithStatus(http.StatusNotFound) + return + } + var settings struct { - Config Config `json:"config"` - Accounts gin.Accounts `json:"accounts"` + Config Config `json:"config" binding:"required"` + Accounts gin.Accounts `json:"accounts" binding:"required"` } if err := c.ShouldBindJSON(&settings); err != nil { @@ -89,6 +95,28 @@ func CreateApplicationRouter(applicationContext *ApplicationContext) *gin.Engine success(c, rule) } }) + + api.POST("/pcap/file", func(c *gin.Context) { + var body struct { + Path string + } + + if err := c.ShouldBindJSON(&body); err != nil { + badRequest(c, err) + return + } + + if !FileExists(body.Path) { + unprocessableEntity(c, errors.New("invalid path")) + return + } + + if sessionID, err := applicationContext.PcapImporter.ImportPcap(body.Path); err != nil { + unprocessableEntity(c, err) + } else { + c.JSON(http.StatusAccepted, gin.H{"session": sessionID}) + } + }) } return router diff --git a/application_router_test.go b/application_router_test.go new file mode 100644 index 0000000..ece38f1 --- /dev/null +++ b/application_router_test.go @@ -0,0 +1,92 @@ +package main + +import ( + "bytes" + "encoding/json" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "io" + "net/http" + "net/http/httptest" + "testing" +) + +func TestSetupApplication(t *testing.T) { + toolkit := NewRouterTestToolkit(t, false) + + settings := make(map[string]interface{}) + assert.Equal(t, http.StatusServiceUnavailable, toolkit.MakeRequest("GET", "/api/rules", nil).Code) + assert.Equal(t, http.StatusBadRequest, toolkit.MakeRequest("POST", "/setup", settings).Code) + settings["config"] = Config{ServerIP: "1.2.3.4", FlagRegex: "FLAG{test}", AuthRequired: true} + assert.Equal(t, http.StatusBadRequest, toolkit.MakeRequest("POST", "/setup", settings).Code) + settings["accounts"] = gin.Accounts{"username": "password"} + assert.Equal(t, http.StatusAccepted, toolkit.MakeRequest("POST", "/setup", settings).Code) + assert.Equal(t, http.StatusNotFound, toolkit.MakeRequest("POST", "/setup", settings).Code) + + toolkit.wrapper.Destroy(t) +} + +func TestAuthRequired(t *testing.T) { + toolkit := NewRouterTestToolkit(t, true) + + assert.Equal(t, http.StatusOK, toolkit.MakeRequest("GET", "/api/rules", nil).Code) + config := toolkit.appContext.Config + config.AuthRequired = true + toolkit.appContext.SetConfig(config) + toolkit.appContext.SetAccounts(gin.Accounts{"username": "password"}) + assert.Equal(t, http.StatusUnauthorized, toolkit.MakeRequest("GET", "/api/rules", nil).Code) + + toolkit.wrapper.Destroy(t) +} + +type RouterTestToolkit struct { + appContext *ApplicationContext + wrapper *TestStorageWrapper + router *gin.Engine + t *testing.T +} + +func NewRouterTestToolkit(t *testing.T, withSetup bool) *RouterTestToolkit { + wrapper := NewTestStorageWrapper(t) + wrapper.AddCollection(Settings) + + appContext, err := CreateApplicationContext(wrapper.Storage) + require.NoError(t, err) + gin.SetMode(gin.ReleaseMode) + router := CreateApplicationRouter(appContext) + + toolkit := RouterTestToolkit{ + appContext: appContext, + wrapper: wrapper, + router: router, + t: t, + } + + if withSetup { + settings := gin.H{ + "config": Config{ServerIP: "1.2.3.4", FlagRegex: "FLAG{test}", AuthRequired: false}, + "accounts": gin.Accounts{}, + } + toolkit.MakeRequest("POST", "/setup", settings) + } + + return &toolkit +} + +func (rtt *RouterTestToolkit) MakeRequest(method string, url string, body interface{}) *httptest.ResponseRecorder { + var r io.Reader + + if body != nil { + buf, err := json.Marshal(body) + require.NoError(rtt.t, err) + r = bytes.NewBuffer(buf) + } + + w := httptest.NewRecorder() + req, err := http.NewRequest(method, url, r) + require.NoError(rtt.t, err) + rtt.router.ServeHTTP(w, req) + + return w +} diff --git a/connection_handler.go b/connection_handler.go index e4730cc..53e594f 100644 --- a/connection_handler.go +++ b/connection_handler.go @@ -6,6 +6,7 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/tcpassembly" log "github.com/sirupsen/logrus" + "hash/fnv" "sync" "time" ) @@ -237,5 +238,10 @@ func (ch *connectionHandlerImpl) PatternsDatabaseSize() int { } func (sf StreamFlow) Hash() uint64 { - return sf[0].FastHash() ^ sf[1].FastHash() ^ sf[2].FastHash() ^ sf[3].FastHash() + hash := fnv.New64a() + _, _ = hash.Write(sf[0].Raw()) + _, _ = hash.Write(sf[1].Raw()) + _, _ = hash.Write(sf[2].Raw()) + _, _ = hash.Write(sf[3].Raw()) + return hash.Sum64() } diff --git a/pcap_importer.go b/pcap_importer.go index bb09867..24ce2cf 100644 --- a/pcap_importer.go +++ b/pcap_importer.go @@ -29,6 +29,8 @@ type PcapImporter struct { type ImportingSession struct { ID string `json:"id" bson:"_id"` + StartedAt time.Time `json:"started_at" bson:"started_at"` + Size int64 `json:"size" bson:"size"` CompletedAt time.Time `json:"completed_at" bson:"completed_at,omitempty"` ProcessedPackets int `json:"processed_packets" bson:"processed_packets"` InvalidPackets int `json:"invalid_packets" bson:"invalid_packets"` @@ -75,6 +77,8 @@ func (pi *PcapImporter) ImportPcap(fileName string) (string, error) { ctx, cancelFunc := context.WithCancel(context.Background()) session := ImportingSession{ ID: hash, + StartedAt: time.Now(), + Size: FileSize(fileName), PacketsPerService: make(map[uint16]flowCount), cancelFunc: cancelFunc, completed: make(chan string), diff --git a/pcap_importer_test.go b/pcap_importer_test.go index bda2cb2..8d03e4e 100644 --- a/pcap_importer_test.go +++ b/pcap_importer_test.go @@ -53,6 +53,7 @@ func TestCancelImportSession(t *testing.T) { session := waitSessionCompletion(t, pcapImporter, sessionID) assert.Zero(t, session.CompletedAt) + assert.Equal(t, 1270696, session.Size) assert.Equal(t, 0, session.ProcessedPackets) assert.Equal(t, 0, session.InvalidPackets) assert.Equal(t, map[uint16]flowCount{}, session.PacketsPerService) @@ -71,6 +72,7 @@ func TestImportNoTcpPackets(t *testing.T) { require.NoError(t, err) session := waitSessionCompletion(t, pcapImporter, sessionID) + assert.Equal(t, 228024, session.Size) assert.Equal(t, 2000, session.ProcessedPackets) assert.Equal(t, 2000, session.InvalidPackets) assert.Equal(t, map[uint16]flowCount{}, session.PacketsPerService) @@ -114,7 +116,10 @@ func checkSessionEquals(t *testing.T, wrapper *TestStorageWrapper, session Impor var result ImportingSession assert.NoError(t, wrapper.Storage.Find(ImportingSessions).Filter(OrderedDocument{{"_id", session.ID}}). Context(wrapper.Context).First(&result)) + assert.Equal(t, session.StartedAt.Unix(), result.StartedAt.Unix()) assert.Equal(t, session.CompletedAt.Unix(), result.CompletedAt.Unix()) + session.StartedAt = time.Time{} + result.StartedAt = time.Time{} session.CompletedAt = time.Time{} result.CompletedAt = time.Time{} session.cancelFunc = nil diff --git a/stream_handler.go b/stream_handler.go index a436fd5..97975fa 100644 --- a/stream_handler.go +++ b/stream_handler.go @@ -149,7 +149,7 @@ func (sh *StreamHandler) storageCurrentDocument() { payload := sh.streamFlow.Hash()&uint64(0xffffffffffffff00) | uint64(len(sh.documentsIDs)) // LOL streamID := CustomRowID(payload, sh.firstPacketSeen) - _, err := sh.connection.Storage().Insert(ConnectionStreams). + if _, err := sh.connection.Storage().Insert(ConnectionStreams). One(ConnectionStream{ ID: streamID, ConnectionID: ZeroRowID, @@ -159,9 +159,7 @@ func (sh *StreamHandler) storageCurrentDocument() { BlocksTimestamps: sh.timestamps, BlocksLoss: sh.lossBlocks, PatternMatches: sh.patternMatches, - }) - - if err != nil { + }); err != nil { log.WithError(err).Error("failed to insert connection stream") } else { sh.documentsIDs = append(sh.documentsIDs, streamID) diff --git a/utils.go b/utils.go index b9cdd8c..de83ecb 100644 --- a/utils.go +++ b/utils.go @@ -57,3 +57,19 @@ func RowIDFromHex(hex string) (RowID, error) { rowID, err := primitive.ObjectIDFromHex(hex) return rowID, err } + +func FileExists(filename string) bool { + info, err := os.Stat(filename) + if os.IsNotExist(err) { + return false + } + return !info.IsDir() +} + +func FileSize(filename string) int64 { + info, err := os.Stat(filename) + if os.IsNotExist(err) { + return -1 + } + return info.Size() +} -- cgit v1.2.3-70-g09d2