aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--application_router.go32
-rw-r--r--application_router_test.go92
-rw-r--r--connection_handler.go8
-rw-r--r--pcap_importer.go4
-rw-r--r--pcap_importer_test.go5
-rw-r--r--stream_handler.go6
-rw-r--r--utils.go16
7 files changed, 156 insertions, 7 deletions
diff --git a/application_router.go b/application_router.go
index d79fee8..bdc8007 100644
--- a/application_router.go
+++ b/application_router.go
@@ -1,6 +1,7 @@
package main
import (
+ "errors"
"github.com/gin-gonic/gin"
"net/http"
)
@@ -13,9 +14,14 @@ func CreateApplicationRouter(applicationContext *ApplicationContext) *gin.Engine
// engine.Static("/", "./frontend/build")
router.POST("/setup", func(c *gin.Context) {
+ if applicationContext.IsConfigured {
+ c.AbortWithStatus(http.StatusNotFound)
+ return
+ }
+
var settings struct {
- Config Config `json:"config"`
- Accounts gin.Accounts `json:"accounts"`
+ Config Config `json:"config" binding:"required"`
+ Accounts gin.Accounts `json:"accounts" binding:"required"`
}
if err := c.ShouldBindJSON(&settings); err != nil {
@@ -89,6 +95,28 @@ func CreateApplicationRouter(applicationContext *ApplicationContext) *gin.Engine
success(c, rule)
}
})
+
+ api.POST("/pcap/file", func(c *gin.Context) {
+ var body struct {
+ Path string
+ }
+
+ if err := c.ShouldBindJSON(&body); err != nil {
+ badRequest(c, err)
+ return
+ }
+
+ if !FileExists(body.Path) {
+ unprocessableEntity(c, errors.New("invalid path"))
+ return
+ }
+
+ if sessionID, err := applicationContext.PcapImporter.ImportPcap(body.Path); err != nil {
+ unprocessableEntity(c, err)
+ } else {
+ c.JSON(http.StatusAccepted, gin.H{"session": sessionID})
+ }
+ })
}
return router
diff --git a/application_router_test.go b/application_router_test.go
new file mode 100644
index 0000000..ece38f1
--- /dev/null
+++ b/application_router_test.go
@@ -0,0 +1,92 @@
+package main
+
+import (
+ "bytes"
+ "encoding/json"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+)
+
+func TestSetupApplication(t *testing.T) {
+ toolkit := NewRouterTestToolkit(t, false)
+
+ settings := make(map[string]interface{})
+ assert.Equal(t, http.StatusServiceUnavailable, toolkit.MakeRequest("GET", "/api/rules", nil).Code)
+ assert.Equal(t, http.StatusBadRequest, toolkit.MakeRequest("POST", "/setup", settings).Code)
+ settings["config"] = Config{ServerIP: "1.2.3.4", FlagRegex: "FLAG{test}", AuthRequired: true}
+ assert.Equal(t, http.StatusBadRequest, toolkit.MakeRequest("POST", "/setup", settings).Code)
+ settings["accounts"] = gin.Accounts{"username": "password"}
+ assert.Equal(t, http.StatusAccepted, toolkit.MakeRequest("POST", "/setup", settings).Code)
+ assert.Equal(t, http.StatusNotFound, toolkit.MakeRequest("POST", "/setup", settings).Code)
+
+ toolkit.wrapper.Destroy(t)
+}
+
+func TestAuthRequired(t *testing.T) {
+ toolkit := NewRouterTestToolkit(t, true)
+
+ assert.Equal(t, http.StatusOK, toolkit.MakeRequest("GET", "/api/rules", nil).Code)
+ config := toolkit.appContext.Config
+ config.AuthRequired = true
+ toolkit.appContext.SetConfig(config)
+ toolkit.appContext.SetAccounts(gin.Accounts{"username": "password"})
+ assert.Equal(t, http.StatusUnauthorized, toolkit.MakeRequest("GET", "/api/rules", nil).Code)
+
+ toolkit.wrapper.Destroy(t)
+}
+
+type RouterTestToolkit struct {
+ appContext *ApplicationContext
+ wrapper *TestStorageWrapper
+ router *gin.Engine
+ t *testing.T
+}
+
+func NewRouterTestToolkit(t *testing.T, withSetup bool) *RouterTestToolkit {
+ wrapper := NewTestStorageWrapper(t)
+ wrapper.AddCollection(Settings)
+
+ appContext, err := CreateApplicationContext(wrapper.Storage)
+ require.NoError(t, err)
+ gin.SetMode(gin.ReleaseMode)
+ router := CreateApplicationRouter(appContext)
+
+ toolkit := RouterTestToolkit{
+ appContext: appContext,
+ wrapper: wrapper,
+ router: router,
+ t: t,
+ }
+
+ if withSetup {
+ settings := gin.H{
+ "config": Config{ServerIP: "1.2.3.4", FlagRegex: "FLAG{test}", AuthRequired: false},
+ "accounts": gin.Accounts{},
+ }
+ toolkit.MakeRequest("POST", "/setup", settings)
+ }
+
+ return &toolkit
+}
+
+func (rtt *RouterTestToolkit) MakeRequest(method string, url string, body interface{}) *httptest.ResponseRecorder {
+ var r io.Reader
+
+ if body != nil {
+ buf, err := json.Marshal(body)
+ require.NoError(rtt.t, err)
+ r = bytes.NewBuffer(buf)
+ }
+
+ w := httptest.NewRecorder()
+ req, err := http.NewRequest(method, url, r)
+ require.NoError(rtt.t, err)
+ rtt.router.ServeHTTP(w, req)
+
+ return w
+}
diff --git a/connection_handler.go b/connection_handler.go
index e4730cc..53e594f 100644
--- a/connection_handler.go
+++ b/connection_handler.go
@@ -6,6 +6,7 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/tcpassembly"
log "github.com/sirupsen/logrus"
+ "hash/fnv"
"sync"
"time"
)
@@ -237,5 +238,10 @@ func (ch *connectionHandlerImpl) PatternsDatabaseSize() int {
}
func (sf StreamFlow) Hash() uint64 {
- return sf[0].FastHash() ^ sf[1].FastHash() ^ sf[2].FastHash() ^ sf[3].FastHash()
+ hash := fnv.New64a()
+ _, _ = hash.Write(sf[0].Raw())
+ _, _ = hash.Write(sf[1].Raw())
+ _, _ = hash.Write(sf[2].Raw())
+ _, _ = hash.Write(sf[3].Raw())
+ return hash.Sum64()
}
diff --git a/pcap_importer.go b/pcap_importer.go
index bb09867..24ce2cf 100644
--- a/pcap_importer.go
+++ b/pcap_importer.go
@@ -29,6 +29,8 @@ type PcapImporter struct {
type ImportingSession struct {
ID string `json:"id" bson:"_id"`
+ StartedAt time.Time `json:"started_at" bson:"started_at"`
+ Size int64 `json:"size" bson:"size"`
CompletedAt time.Time `json:"completed_at" bson:"completed_at,omitempty"`
ProcessedPackets int `json:"processed_packets" bson:"processed_packets"`
InvalidPackets int `json:"invalid_packets" bson:"invalid_packets"`
@@ -75,6 +77,8 @@ func (pi *PcapImporter) ImportPcap(fileName string) (string, error) {
ctx, cancelFunc := context.WithCancel(context.Background())
session := ImportingSession{
ID: hash,
+ StartedAt: time.Now(),
+ Size: FileSize(fileName),
PacketsPerService: make(map[uint16]flowCount),
cancelFunc: cancelFunc,
completed: make(chan string),
diff --git a/pcap_importer_test.go b/pcap_importer_test.go
index bda2cb2..8d03e4e 100644
--- a/pcap_importer_test.go
+++ b/pcap_importer_test.go
@@ -53,6 +53,7 @@ func TestCancelImportSession(t *testing.T) {
session := waitSessionCompletion(t, pcapImporter, sessionID)
assert.Zero(t, session.CompletedAt)
+ assert.Equal(t, 1270696, session.Size)
assert.Equal(t, 0, session.ProcessedPackets)
assert.Equal(t, 0, session.InvalidPackets)
assert.Equal(t, map[uint16]flowCount{}, session.PacketsPerService)
@@ -71,6 +72,7 @@ func TestImportNoTcpPackets(t *testing.T) {
require.NoError(t, err)
session := waitSessionCompletion(t, pcapImporter, sessionID)
+ assert.Equal(t, 228024, session.Size)
assert.Equal(t, 2000, session.ProcessedPackets)
assert.Equal(t, 2000, session.InvalidPackets)
assert.Equal(t, map[uint16]flowCount{}, session.PacketsPerService)
@@ -114,7 +116,10 @@ func checkSessionEquals(t *testing.T, wrapper *TestStorageWrapper, session Impor
var result ImportingSession
assert.NoError(t, wrapper.Storage.Find(ImportingSessions).Filter(OrderedDocument{{"_id", session.ID}}).
Context(wrapper.Context).First(&result))
+ assert.Equal(t, session.StartedAt.Unix(), result.StartedAt.Unix())
assert.Equal(t, session.CompletedAt.Unix(), result.CompletedAt.Unix())
+ session.StartedAt = time.Time{}
+ result.StartedAt = time.Time{}
session.CompletedAt = time.Time{}
result.CompletedAt = time.Time{}
session.cancelFunc = nil
diff --git a/stream_handler.go b/stream_handler.go
index a436fd5..97975fa 100644
--- a/stream_handler.go
+++ b/stream_handler.go
@@ -149,7 +149,7 @@ func (sh *StreamHandler) storageCurrentDocument() {
payload := sh.streamFlow.Hash()&uint64(0xffffffffffffff00) | uint64(len(sh.documentsIDs)) // LOL
streamID := CustomRowID(payload, sh.firstPacketSeen)
- _, err := sh.connection.Storage().Insert(ConnectionStreams).
+ if _, err := sh.connection.Storage().Insert(ConnectionStreams).
One(ConnectionStream{
ID: streamID,
ConnectionID: ZeroRowID,
@@ -159,9 +159,7 @@ func (sh *StreamHandler) storageCurrentDocument() {
BlocksTimestamps: sh.timestamps,
BlocksLoss: sh.lossBlocks,
PatternMatches: sh.patternMatches,
- })
-
- if err != nil {
+ }); err != nil {
log.WithError(err).Error("failed to insert connection stream")
} else {
sh.documentsIDs = append(sh.documentsIDs, streamID)
diff --git a/utils.go b/utils.go
index b9cdd8c..de83ecb 100644
--- a/utils.go
+++ b/utils.go
@@ -57,3 +57,19 @@ func RowIDFromHex(hex string) (RowID, error) {
rowID, err := primitive.ObjectIDFromHex(hex)
return rowID, err
}
+
+func FileExists(filename string) bool {
+ info, err := os.Stat(filename)
+ if os.IsNotExist(err) {
+ return false
+ }
+ return !info.IsDir()
+}
+
+func FileSize(filename string) int64 {
+ info, err := os.Stat(filename)
+ if os.IsNotExist(err) {
+ return -1
+ }
+ return info.Size()
+}