From 2cffb8800d857e031178e084f602a78282490937 Mon Sep 17 00:00:00 2001 From: JJ Date: Fri, 19 Jul 2024 13:56:29 -0700 Subject: remove tests --- application_context_test.go | 75 --------- application_router_test.go | 209 -------------------------- caronte_test.go | 71 --------- connection_handler_test.go | 227 ---------------------------- pcap_importer_test.go | 177 ---------------------- rules_manager_test.go | 302 ------------------------------------- storage_test.go | 258 ------------------------------- stream_handler_test.go | 359 -------------------------------------------- 8 files changed, 1678 deletions(-) delete mode 100644 application_context_test.go delete mode 100644 application_router_test.go delete mode 100644 caronte_test.go delete mode 100644 connection_handler_test.go delete mode 100644 pcap_importer_test.go delete mode 100644 rules_manager_test.go delete mode 100644 storage_test.go delete mode 100644 stream_handler_test.go diff --git a/application_context_test.go b/application_context_test.go deleted file mode 100644 index 11d1ed4..0000000 --- a/application_context_test.go +++ /dev/null @@ -1,75 +0,0 @@ -/* - * This file is part of caronte (https://github.com/eciavatta/caronte). - * Copyright (c) 2020 Emiliano Ciavatta. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, version 3. - * - * This program is distributed in the hope that it will be useful, but - * WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - * General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ - -package main - -import ( - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestCreateApplicationContext(t *testing.T) { - wrapper := NewTestStorageWrapper(t) - wrapper.AddCollection(Settings) - - appContext, err := CreateApplicationContext(wrapper.Storage, "test") - assert.NoError(t, err) - assert.False(t, appContext.IsConfigured) - assert.Zero(t, appContext.Config) - assert.Len(t, appContext.Accounts, 0) - assert.Nil(t, appContext.PcapImporter) - assert.Nil(t, appContext.RulesManager) - - notificationController := NewNotificationController(appContext) - appContext.SetNotificationController(notificationController) - assert.Equal(t, notificationController, appContext.NotificationController) - - config := Config{ - ServerAddress: "10.10.10.10", - FlagRegex: "FLAG{test}", - AuthRequired: true, - } - accounts := gin.Accounts{ - "username": "password", - } - appContext.SetConfig(config) - appContext.SetAccounts(accounts) - assert.Equal(t, appContext.Config, config) - assert.Equal(t, appContext.Accounts, accounts) - assert.NotNil(t, appContext.PcapImporter) - assert.NotNil(t, appContext.RulesManager) - assert.True(t, appContext.IsConfigured) - - config.FlagRegex = "FLAG{test2}" - accounts["username"] = "password2" - appContext.SetConfig(config) - appContext.SetAccounts(accounts) - - checkAppContext, err := CreateApplicationContext(wrapper.Storage, "test") - assert.NoError(t, err) - checkAppContext.SetNotificationController(notificationController) - checkAppContext.Configure() - assert.True(t, checkAppContext.IsConfigured) - assert.Equal(t, checkAppContext.Config, config) - assert.Equal(t, checkAppContext.Accounts, accounts) - assert.NotNil(t, checkAppContext.PcapImporter) - assert.NotNil(t, checkAppContext.RulesManager) - assert.Equal(t, notificationController, appContext.NotificationController) - - wrapper.Destroy(t) -} diff --git a/application_router_test.go b/application_router_test.go deleted file mode 100644 index d4b545f..0000000 --- a/application_router_test.go +++ /dev/null @@ -1,209 +0,0 @@ -/* - * This file is part of caronte (https://github.com/eciavatta/caronte). - * Copyright (c) 2020 Emiliano Ciavatta. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, version 3. - * - * This program is distributed in the hope that it will be useful, but - * WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - * General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ - -package main - -import ( - "bytes" - "encoding/json" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "io" - "net/http" - "net/http/httptest" - "testing" - "time" -) - -func TestSetupApplication(t *testing.T) { - toolkit := NewRouterTestToolkit(t, false) - - settings := make(map[string]interface{}) - assert.Equal(t, http.StatusServiceUnavailable, toolkit.MakeRequest("GET", "/api/rules", nil).Code) - assert.Equal(t, http.StatusBadRequest, toolkit.MakeRequest("POST", "/setup", settings).Code) - settings["config"] = Config{ServerAddress: "1.2.3.4", FlagRegex: "FLAG{test}", AuthRequired: true} - assert.Equal(t, http.StatusBadRequest, toolkit.MakeRequest("POST", "/setup", settings).Code) - settings["accounts"] = gin.Accounts{"username": "password"} - assert.Equal(t, http.StatusAccepted, toolkit.MakeRequest("POST", "/setup", settings).Code) - assert.Equal(t, http.StatusNotFound, toolkit.MakeRequest("POST", "/setup", settings).Code) - - toolkit.wrapper.Destroy(t) -} - -func TestAuthRequired(t *testing.T) { - toolkit := NewRouterTestToolkit(t, true) - - assert.Equal(t, http.StatusOK, toolkit.MakeRequest("GET", "/api/rules", nil).Code) - config := toolkit.appContext.Config - config.AuthRequired = true - toolkit.appContext.SetConfig(config) - toolkit.appContext.SetAccounts(gin.Accounts{"username": "password"}) - assert.Equal(t, http.StatusUnauthorized, toolkit.MakeRequest("GET", "/api/rules", nil).Code) - - toolkit.wrapper.Destroy(t) -} - -func TestRulesApi(t *testing.T) { - toolkit := NewRouterTestToolkit(t, true) - - // AddRule - assert.Equal(t, http.StatusBadRequest, toolkit.MakeRequest("POST", "/api/rules", Rule{}).Code) - assert.Equal(t, http.StatusBadRequest, toolkit.MakeRequest("POST", "/api/rules", - Rule{Name: "testRule"}).Code) - assert.Equal(t, http.StatusBadRequest, toolkit.MakeRequest("POST", "/api/rules", - Rule{Name: "testRule", Color: "invalidColor"}).Code) - w := toolkit.MakeRequest("POST", "/api/rules", Rule{Name: "testRule", Color: "#fff"}) - var testRuleID struct{ ID string } - assert.Equal(t, http.StatusOK, w.Code) - assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &testRuleID)) - assert.Equal(t, http.StatusUnprocessableEntity, toolkit.MakeRequest("POST", "/api/rules", - Rule{Name: "testRule", Color: "#fff"}).Code) // same name - - // UpdateRule - assert.Equal(t, http.StatusBadRequest, toolkit.MakeRequest("PUT", "/api/rules/invalidID", - Rule{Name: "invalidRule", Color: "#000"}).Code) - assert.Equal(t, http.StatusNotFound, toolkit.MakeRequest("PUT", "/api/rules/000000000000000000000000", - Rule{Name: "invalidRule", Color: "#000"}).Code) - assert.Equal(t, http.StatusBadRequest, toolkit.MakeRequest("PUT", "/api/rules/"+testRuleID.ID, Rule{}).Code) - assert.Equal(t, http.StatusBadRequest, toolkit.MakeRequest("PUT", "/api/rules/"+testRuleID.ID, - Rule{Name: "invalidRule", Color: "invalidColor"}).Code) - w = toolkit.MakeRequest("POST", "/api/rules", Rule{Name: "testRule2", Color: "#eee"}) - var testRule2ID struct{ ID string } - assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &testRule2ID)) - assert.Equal(t, http.StatusBadRequest, toolkit.MakeRequest("PUT", "/api/rules/"+testRule2ID.ID, - Rule{Name: "testRule", Color: "#fff"}).Code) // duplicate - w = toolkit.MakeRequest("PUT", "/api/rules/"+testRuleID.ID, Rule{Name: "newRule1", Color: "#ddd"}) - var testRule Rule - assert.Equal(t, http.StatusOK, w.Code) - assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &testRule)) - assert.Equal(t, "newRule1", testRule.Name) - assert.Equal(t, "#ddd", testRule.Color) - - // GetRule - assert.Equal(t, http.StatusBadRequest, toolkit.MakeRequest("GET", "/api/rules/invalidID", nil).Code) - assert.Equal(t, http.StatusNotFound, toolkit.MakeRequest("GET", "/api/rules/000000000000000000000000", nil).Code) - w = toolkit.MakeRequest("GET", "/api/rules/"+testRuleID.ID, nil) - assert.Equal(t, http.StatusOK, w.Code) - assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &testRule)) - assert.Equal(t, testRuleID.ID, testRule.ID.Hex()) - assert.Equal(t, "newRule1", testRule.Name) - assert.Equal(t, "#ddd", testRule.Color) - - // GetRules - w = toolkit.MakeRequest("GET", "/api/rules", nil) - var rules []Rule - assert.Equal(t, http.StatusOK, w.Code) - assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &rules)) - assert.Len(t, rules, 4) - - toolkit.wrapper.Destroy(t) -} - -func TestPcapImporterApi(t *testing.T) { - toolkit := NewRouterTestToolkit(t, true) - - // Import pcap - assert.Equal(t, http.StatusBadRequest, toolkit.MakeRequest("POST", "/api/pcap/file", nil).Code) - assert.Equal(t, http.StatusBadRequest, toolkit.MakeRequest("POST", "/api/pcap/file", - gin.H{"file": "invalidPath"}).Code) - w := toolkit.MakeRequest("POST", "/api/pcap/file", gin.H{"file": "test_data/ping_pong_10000.pcap"}) - var sessionID struct{ Session string } - assert.Equal(t, http.StatusAccepted, w.Code) - assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &sessionID)) - assert.Equal(t, "369ef4b6abb6214b4ee2e0c81ecb93c49e275c26c85e30493b37727d408cf280", sessionID.Session) - assert.Equal(t, http.StatusUnprocessableEntity, toolkit.MakeRequest("POST", "/api/pcap/file", - gin.H{"file": "test_data/ping_pong_10000.pcap"}).Code) // duplicate - - // Get sessions - var sessions []ImportingSession - w = toolkit.MakeRequest("GET", "/api/pcap/sessions", nil) - assert.Equal(t, http.StatusOK, w.Code) - assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &sessions)) - assert.Len(t, sessions, 1) - assert.Equal(t, sessionID.Session, sessions[0].ID) - - // Get session - var session ImportingSession - w = toolkit.MakeRequest("GET", "/api/pcap/sessions/"+sessionID.Session, nil) - assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &session)) - assert.Equal(t, sessionID.Session, session.ID) - - // Cancel session - assert.Equal(t, http.StatusNotFound, toolkit.MakeRequest("DELETE", "/api/pcap/sessions/invalidSession", - nil).Code) - assert.Equal(t, http.StatusAccepted, toolkit.MakeRequest("DELETE", "/api/pcap/sessions/"+sessionID.Session, - nil).Code) - - time.Sleep(1 * time.Second) // wait for termination - - toolkit.wrapper.Destroy(t) -} - -type RouterTestToolkit struct { - appContext *ApplicationContext - wrapper *TestStorageWrapper - router *gin.Engine - t *testing.T -} - -func NewRouterTestToolkit(t *testing.T, withSetup bool) *RouterTestToolkit { - wrapper := NewTestStorageWrapper(t) - wrapper.AddCollection(Settings) - - appContext, err := CreateApplicationContext(wrapper.Storage, "test") - require.NoError(t, err) - gin.SetMode(gin.ReleaseMode) - notificationController := NewNotificationController(appContext) - go notificationController.Run() - resourcesController := NewResourcesController(notificationController) - router := CreateApplicationRouter(appContext, notificationController, resourcesController) - - toolkit := RouterTestToolkit{ - appContext: appContext, - wrapper: wrapper, - router: router, - t: t, - } - - if withSetup { - settings := gin.H{ - "config": Config{ServerAddress: "1.2.3.4", FlagRegex: "FLAG{test}", AuthRequired: false}, - "accounts": gin.Accounts{}, - } - toolkit.MakeRequest("POST", "/setup", settings) - } - - return &toolkit -} - -func (rtt *RouterTestToolkit) MakeRequest(method string, url string, body interface{}) *httptest.ResponseRecorder { - var r io.Reader - - if body != nil { - buf, err := json.Marshal(body) - require.NoError(rtt.t, err) - r = bytes.NewBuffer(buf) - } - - w := httptest.NewRecorder() - req, err := http.NewRequest(method, url, r) - require.NoError(rtt.t, err) - rtt.router.ServeHTTP(w, req) - - return w -} diff --git a/caronte_test.go b/caronte_test.go deleted file mode 100644 index 76c776f..0000000 --- a/caronte_test.go +++ /dev/null @@ -1,71 +0,0 @@ -/* - * This file is part of caronte (https://github.com/eciavatta/caronte). - * Copyright (c) 2020 Emiliano Ciavatta. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, version 3. - * - * This program is distributed in the hope that it will be useful, but - * WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - * General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ - -package main - -import ( - "context" - "fmt" - "os" - "strconv" - "testing" - "time" - - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/require" -) - -type TestStorageWrapper struct { - DbName string - Storage *MongoStorage - Context context.Context -} - -func NewTestStorageWrapper(t *testing.T) *TestStorageWrapper { - mongoHost, ok := os.LookupEnv("MONGO_HOST") - if !ok { - mongoHost = "localhost" - } - mongoPort, ok := os.LookupEnv("MONGO_PORT") - if !ok { - mongoPort = "27017" - } - port, err := strconv.Atoi(mongoPort) - require.NoError(t, err, "invalid port") - - dbName := fmt.Sprintf("%x", time.Now().UnixNano()) - log.WithField("dbName", dbName).Info("creating new storage") - - storage, err := NewMongoStorage(mongoHost, port, dbName) - require.NoError(t, err) - ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) - - return &TestStorageWrapper{ - DbName: dbName, - Storage: storage, - Context: ctx, - } -} - -func (tsw TestStorageWrapper) AddCollection(collectionName string) { - tsw.Storage.collections[collectionName] = tsw.Storage.client.Database(tsw.DbName).Collection(collectionName) -} - -func (tsw TestStorageWrapper) Destroy(t *testing.T) { - err := tsw.Storage.client.Disconnect(tsw.Context) - require.NoError(t, err, "failed to disconnect to database") -} diff --git a/connection_handler_test.go b/connection_handler_test.go deleted file mode 100644 index 942310a..0000000 --- a/connection_handler_test.go +++ /dev/null @@ -1,227 +0,0 @@ -/* - * This file is part of caronte (https://github.com/eciavatta/caronte). - * Copyright (c) 2020 Emiliano Ciavatta. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, version 3. - * - * This program is distributed in the hope that it will be useful, but - * WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - * General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ - -package main - -import ( - "context" - "encoding/binary" - "github.com/flier/gohs/hyperscan" - "github.com/google/gopacket" - "github.com/google/gopacket/layers" - "github.com/google/gopacket/tcpassembly" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "math/rand" - "net" - "testing" - "time" -) - -func TestTakeReleaseScanners(t *testing.T) { - wrapper := NewTestStorageWrapper(t) - serverNet := ParseIPNet(testDstIP) - ruleManager := TestRulesManager{ - databaseUpdated: make(chan RulesDatabase), - } - - database, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/nope/", 0)) - require.NoError(t, err) - - factory := NewBiDirectionalStreamFactory(wrapper.Storage, *serverNet, &ruleManager) - version := NewRowID() - ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version} - time.Sleep(10 * time.Millisecond) - - n := 1000 - for i := 0; i < n; i++ { - scanner := factory.takeScanner() - assert.Equal(t, scanner.version, version) - - if i%50 == 0 { - version = NewRowID() - ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version} - time.Sleep(10 * time.Millisecond) - } - factory.releaseScanner(scanner) - } - assert.Len(t, factory.scanners, 1) - - scanners := make([]Scanner, n) - for i := 0; i < n; i++ { - scanners[i] = factory.takeScanner() - assert.Equal(t, scanners[i].version, version) - } - for i := 0; i < n; i++ { - factory.releaseScanner(scanners[i]) - } - assert.Len(t, factory.scanners, n) - - version = NewRowID() - ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version} - time.Sleep(10 * time.Millisecond) - - for i := 0; i < n; i++ { - scanners[i] = factory.takeScanner() - assert.Equal(t, scanners[i].version, version) - factory.releaseScanner(scanners[i]) - } - - close(ruleManager.DatabaseUpdateChannel()) - wrapper.Destroy(t) -} - -func TestConnectionFactory(t *testing.T) { - wrapper := NewTestStorageWrapper(t) - wrapper.AddCollection(Connections) - wrapper.AddCollection(ConnectionStreams) - - ruleManager := TestRulesManager{ - databaseUpdated: make(chan RulesDatabase), - } - - clientIP := layers.NewIPEndpoint(net.ParseIP(testSrcIP)) - serverIP := layers.NewIPEndpoint(net.ParseIP(testDstIP)) - serverPort := layers.NewTCPPortEndpoint(dstPort) - clientServerNetFlow, err := gopacket.FlowFromEndpoints(clientIP, serverIP) - require.NoError(t, err) - serverClientNetFlow, err := gopacket.FlowFromEndpoints(serverIP, clientIP) - require.NoError(t, err) - - database, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/nope/", 0)) - require.NoError(t, err) - - factory := NewBiDirectionalStreamFactory(wrapper.Storage, *ParseIPNet(testDstIP), &ruleManager) - version := NewRowID() - ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version} - time.Sleep(10 * time.Millisecond) - - testInteraction := func(netFlow gopacket.Flow, transportFlow gopacket.Flow, otherSeenChan chan time.Time, - completed chan bool) { - - time.Sleep(time.Duration(rand.Intn(1000)) * time.Millisecond) - stream := factory.New(netFlow, transportFlow) - seen := time.Now() - stream.Reassembled([]tcpassembly.Reassembly{{[]byte{}, 0, true, true, seen}}) - stream.ReassemblyComplete() - - var startedAt, closedAt time.Time - if netFlow == serverClientNetFlow { - otherSeenChan <- seen - return - } - - otherSeen, ok := <-otherSeenChan - require.True(t, ok) - - if seen.Before(otherSeen) { - startedAt = seen - closedAt = otherSeen - } else { - startedAt = otherSeen - closedAt = seen - } - close(otherSeenChan) - - var result Connection - connectionFlow := StreamFlow{netFlow.Src(), netFlow.Dst(), transportFlow.Src(), transportFlow.Dst()} - connectionID := CustomRowID(connectionFlow.Hash(), startedAt) - op := wrapper.Storage.Find(Connections).Context(wrapper.Context) - err := op.Filter(OrderedDocument{{"_id", connectionID}}).First(&result) - require.NoError(t, err) - - assert.NotNil(t, result) - assert.Equal(t, CustomRowID(connectionFlow.Hash(), result.StartedAt), result.ID) - assert.Equal(t, netFlow.Src().String(), result.SourceIP) - assert.Equal(t, netFlow.Dst().String(), result.DestinationIP) - assert.Equal(t, binary.BigEndian.Uint16(transportFlow.Src().Raw()), result.SourcePort) - assert.Equal(t, binary.BigEndian.Uint16(transportFlow.Dst().Raw()), result.DestinationPort) - assert.Equal(t, startedAt.Unix(), result.StartedAt.Unix()) - assert.Equal(t, closedAt.Unix(), result.ClosedAt.Unix()) - - completed <- true - } - - completed := make(chan bool) - n := 1000 - - for port := 40000; port < 40000+n; port++ { - clientPort := layers.NewTCPPortEndpoint(layers.TCPPort(port)) - clientServerTransportFlow, err := gopacket.FlowFromEndpoints(clientPort, serverPort) - require.NoError(t, err) - serverClientTransportFlow, err := gopacket.FlowFromEndpoints(serverPort, clientPort) - require.NoError(t, err) - - otherSeenChan := make(chan time.Time) - go testInteraction(clientServerNetFlow, clientServerTransportFlow, otherSeenChan, completed) - go testInteraction(serverClientNetFlow, serverClientTransportFlow, otherSeenChan, completed) - } - - timeout := time.Tick(10 * time.Second) - for i := 0; i < n; i++ { - select { - case <-completed: - continue - case <-timeout: - t.Fatal("timeout") - } - } - - assert.Len(t, factory.connections, 0) - - close(ruleManager.DatabaseUpdateChannel()) - wrapper.Destroy(t) -} - -type TestRulesManager struct { - databaseUpdated chan RulesDatabase -} - -func (rm TestRulesManager) LoadRules() error { - return nil -} - -func (rm TestRulesManager) AddRule(_ context.Context, _ Rule) (RowID, error) { - return RowID{}, nil -} - -func (rm TestRulesManager) DeleteRule(_ context.Context, _ RowID) error { - return nil -} - -func (rm TestRulesManager) GetRule(_ RowID) (Rule, bool) { - return Rule{}, false -} - -func (rm TestRulesManager) UpdateRule(_ context.Context, _ RowID, _ Rule) (bool, error) { - return false, nil -} - -func (rm TestRulesManager) GetRules() []Rule { - return nil -} - -func (rm TestRulesManager) SetFlag(_ context.Context, _ string) error { - return nil -} - -func (rm TestRulesManager) FillWithMatchedRules(_ *Connection, _ map[uint][]PatternSlice, _ map[uint][]PatternSlice) { -} - -func (rm TestRulesManager) DatabaseUpdateChannel() chan RulesDatabase { - return rm.databaseUpdated -} diff --git a/pcap_importer_test.go b/pcap_importer_test.go deleted file mode 100644 index 4761927..0000000 --- a/pcap_importer_test.go +++ /dev/null @@ -1,177 +0,0 @@ -/* - * This file is part of caronte (https://github.com/eciavatta/caronte). - * Copyright (c) 2020 Emiliano Ciavatta. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, version 3. - * - * This program is distributed in the hope that it will be useful, but - * WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - * General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ - -package main - -import ( - "bufio" - "fmt" - "github.com/google/gopacket" - "github.com/google/gopacket/tcpassembly" - "github.com/google/gopacket/tcpassembly/tcpreader" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "os" - "sync" - "testing" - "time" -) - -func TestImportPcap(t *testing.T) { - wrapper := NewTestStorageWrapper(t) - pcapImporter := newTestPcapImporter(wrapper, "172.17.0.3") - - pcapImporter.releaseAssembler(pcapImporter.takeAssembler()) - - fileName := copyToProcessing(t, "ping_pong_10000.pcap") - sessionID, err := pcapImporter.ImportPcap(fileName, false) - require.NoError(t, err) - - duplicatePcapFileName := copyToProcessing(t, "ping_pong_10000.pcap") - duplicateSessionID, err := pcapImporter.ImportPcap(duplicatePcapFileName, false) - require.Error(t, err) - assert.Equal(t, sessionID, duplicateSessionID) - assert.Error(t, os.Remove(ProcessingPcapsBasePath+duplicatePcapFileName)) - - _, isPresent := pcapImporter.GetSession("invalid") - assert.False(t, isPresent) - - session := waitSessionCompletion(t, pcapImporter, sessionID) - assert.Equal(t, 15008, session.ProcessedPackets) - assert.Equal(t, 0, session.InvalidPackets) - assert.Equal(t, map[uint16]flowCount{9999: {10004, 5004}}, session.PacketsPerService) - assert.Zero(t, session.ImportingError) - - checkSessionEquals(t, wrapper, session) - - assert.Error(t, os.Remove(ProcessingPcapsBasePath+fileName)) - assert.NoError(t, os.Remove(PcapsBasePath+session.ID+".pcap")) - - wrapper.Destroy(t) -} - -func TestCancelImportSession(t *testing.T) { - wrapper := NewTestStorageWrapper(t) - pcapImporter := newTestPcapImporter(wrapper, "172.17.0.3") - - fileName := copyToProcessing(t, "ping_pong_10000.pcap") - sessionID, err := pcapImporter.ImportPcap(fileName, false) - require.NoError(t, err) - - assert.False(t, pcapImporter.CancelSession("invalid")) - assert.True(t, pcapImporter.CancelSession(sessionID)) - - session := waitSessionCompletion(t, pcapImporter, sessionID) - assert.Zero(t, session.CompletedAt) - assert.Equal(t, int64(1270696), session.Size) - // assert.Equal(t, 0, session.ProcessedPackets) // TODO: investigate - assert.Equal(t, 0, session.InvalidPackets) - // assert.Equal(t, map[uint16]flowCount{}, session.PacketsPerService) - assert.NotZero(t, session.ImportingError) - - checkSessionEquals(t, wrapper, session) - - assert.Error(t, os.Remove(ProcessingPcapsBasePath+fileName)) - assert.Error(t, os.Remove(PcapsBasePath+sessionID+".pcap")) - - wrapper.Destroy(t) -} - -func TestImportNoTcpPackets(t *testing.T) { - wrapper := NewTestStorageWrapper(t) - pcapImporter := newTestPcapImporter(wrapper, "172.17.0.4") - - fileName := copyToProcessing(t, "icmp.pcap") - sessionID, err := pcapImporter.ImportPcap(fileName, false) - require.NoError(t, err) - - session := waitSessionCompletion(t, pcapImporter, sessionID) - assert.Equal(t, int64(228024), session.Size) - assert.Equal(t, 2000, session.ProcessedPackets) - assert.Equal(t, 2000, session.InvalidPackets) - assert.Equal(t, map[uint16]flowCount{}, session.PacketsPerService) - assert.Zero(t, session.ImportingError) - - checkSessionEquals(t, wrapper, session) - - assert.Error(t, os.Remove(ProcessingPcapsBasePath+fileName)) - assert.NoError(t, os.Remove(PcapsBasePath+sessionID+".pcap")) - - wrapper.Destroy(t) -} - -func newTestPcapImporter(wrapper *TestStorageWrapper, serverAddress string) *PcapImporter { - wrapper.AddCollection(ImportingSessions) - - streamPool := tcpassembly.NewStreamPool(&testStreamFactory{}) - - return &PcapImporter{ - storage: wrapper.Storage, - streamPool: streamPool, - assemblers: make([]*tcpassembly.Assembler, 0, initialAssemblerPoolSize), - sessions: make(map[string]ImportingSession), - mAssemblers: sync.Mutex{}, - mSessions: sync.Mutex{}, - serverNet: *ParseIPNet(serverAddress), - notificationController: NewNotificationController(nil), - } -} - -func waitSessionCompletion(t *testing.T, pcapImporter *PcapImporter, sessionID string) ImportingSession { - session, isPresent := pcapImporter.GetSession(sessionID) - require.True(t, isPresent) - <-session.completed - - session, isPresent = pcapImporter.GetSession(sessionID) - assert.True(t, isPresent) - assert.Equal(t, sessionID, session.ID) - - return session -} - -func checkSessionEquals(t *testing.T, wrapper *TestStorageWrapper, session ImportingSession) { - var result ImportingSession - assert.NoError(t, wrapper.Storage.Find(ImportingSessions).Filter(OrderedDocument{{"_id", session.ID}}). - Context(wrapper.Context).First(&result)) - assert.Equal(t, session.StartedAt.Unix(), result.StartedAt.Unix()) - assert.Equal(t, session.CompletedAt.Unix(), result.CompletedAt.Unix()) - session.StartedAt = time.Time{} - result.StartedAt = time.Time{} - session.CompletedAt = time.Time{} - result.CompletedAt = time.Time{} - session.cancelFunc = nil - session.completed = nil - assert.Equal(t, session, result) -} - -func copyToProcessing(t *testing.T, fileName string) string { - newFile := fmt.Sprintf("test-%v-%s", time.Now().UnixNano(), fileName) - require.NoError(t, CopyFile(ProcessingPcapsBasePath+newFile, "test_data/"+fileName)) - return newFile -} - -type testStreamFactory struct { -} - -func (sf *testStreamFactory) New(_, _ gopacket.Flow) tcpassembly.Stream { - reader := tcpreader.NewReaderStream() - go func() { - buffer := bufio.NewReader(&reader) - tcpreader.DiscardBytesToEOF(buffer) - }() - return &reader -} diff --git a/rules_manager_test.go b/rules_manager_test.go deleted file mode 100644 index d3f6d20..0000000 --- a/rules_manager_test.go +++ /dev/null @@ -1,302 +0,0 @@ -/* - * This file is part of caronte (https://github.com/eciavatta/caronte). - * Copyright (c) 2020 Emiliano Ciavatta. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, version 3. - * - * This program is distributed in the hope that it will be useful, but - * WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - * General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ - -package main - -import ( - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestAddAndGetAllRules(t *testing.T) { - wrapper := NewTestStorageWrapper(t) - wrapper.AddCollection(Rules) - - rulesManager, err := LoadRulesManager(wrapper.Storage, "FLAG{test}") - require.NoError(t, err) - impl := rulesManager.(*rulesManagerImpl) - checkVersion(t, rulesManager, impl.rulesByName["flag_out"].ID) - checkVersion(t, rulesManager, impl.rulesByName["flag_in"].ID) - emptyRule := Rule{Name: "empty", Color: "#fff", Enabled: true} - emptyID, err := rulesManager.AddRule(wrapper.Context, emptyRule) - assert.NoError(t, err) - assert.NotNil(t, emptyID) - checkVersion(t, rulesManager, emptyID) - - duplicateRule, err := rulesManager.AddRule(wrapper.Context, Rule{Name: "empty", Color: "#eee"}) - assert.Error(t, err) - assert.Zero(t, duplicateRule) - - invalidPattern, err := rulesManager.AddRule(wrapper.Context, Rule{ - Name: "invalidPattern", - Color: "#eee", - Patterns: []Pattern{ - { - Regex: "invalid)", - }, - }, - }) - assert.Error(t, err) - assert.Zero(t, invalidPattern) - - rule1 := Rule{ - Name: "rule1", - Color: "#eee", - Patterns: []Pattern{ - { - Regex: "pattern1", - Flags: RegexFlags{ - Caseless: true, - DotAll: true, - MultiLine: true, - Utf8Mode: true, - UnicodeProperty: true, - }, - MinOccurrences: 1, - MaxOccurrences: 3, - Direction: DirectionBoth, - }}, - Enabled: true, - } - rule1ID, err := rulesManager.AddRule(wrapper.Context, rule1) - assert.NoError(t, err) - assert.NotNil(t, rule1ID) - checkVersion(t, rulesManager, rule1ID) - - rule2 := Rule{ - Name: "rule2", - Color: "#ddd", - Patterns: []Pattern{ - {Regex: "pattern1"}, - {Regex: "pattern2"}, - }, - Enabled: true, - } - rule2ID, err := rulesManager.AddRule(wrapper.Context, rule2) - assert.NoError(t, err) - assert.NotNil(t, rule2ID) - checkVersion(t, rulesManager, rule2ID) - - rule3 := Rule{ - Name: "rule3", - Color: "#ccc", - Patterns: []Pattern{ - {Regex: "pattern2"}, - {Regex: "pattern3"}, - }, - Enabled: true, - } - rule3ID, err := rulesManager.AddRule(wrapper.Context, rule3) - assert.NoError(t, err) - assert.NotNil(t, rule3ID) - checkVersion(t, rulesManager, rule3ID) - - checkRule := func(expected Rule, patternIDs []int) { - var rule Rule - err := wrapper.Storage.Find(Rules).Context(wrapper.Context). - Filter(OrderedDocument{{"_id", expected.ID}}).First(&rule) - require.NoError(t, err) - - for i, id := range patternIDs { - rule.Patterns[i].internalID = uint(id) - } - assert.Equal(t, expected, rule) - assert.Equal(t, expected, impl.rules[expected.ID]) - assert.Equal(t, expected, impl.rulesByName[expected.Name]) - } - - assert.Len(t, impl.rules, 6) - assert.Len(t, impl.rulesByName, 6) - assert.Len(t, impl.patterns, 5) - assert.Len(t, impl.patternsIds, 5) - - emptyRule.ID = emptyID - rule1.ID = rule1ID - rule2.ID = rule2ID - rule3.ID = rule3ID - - checkRule(emptyRule, []int{}) - checkRule(rule1, []int{1}) - checkRule(rule2, []int{2, 3}) - checkRule(rule3, []int{3, 4}) - - assert.Len(t, rulesManager.GetRules(), 6) - assert.ElementsMatch(t, []Rule{impl.rulesByName["flag_out"], impl.rulesByName["flag_in"], emptyRule, - rule1, rule2, rule3}, rulesManager.GetRules()) - - wrapper.Destroy(t) -} - -func TestLoadAndUpdateRules(t *testing.T) { - wrapper := NewTestStorageWrapper(t) - wrapper.AddCollection(Rules) - - expectedIds := []RowID{NewRowID(), NewRowID(), NewRowID(), NewRowID()} - rules := []interface{}{ - Rule{ID: expectedIds[0], Name: "rule1", Color: "#fff", Patterns: []Pattern{ - {Regex: "/pattern1/", Flags: RegexFlags{Caseless: true}, Direction: DirectionToClient, internalID: 0}, - }}, - Rule{ID: expectedIds[1], Name: "rule2", Color: "#eee", Patterns: []Pattern{ - {Regex: "/pattern2/", MinOccurrences: 1, MaxOccurrences: 3, Direction: DirectionToServer, internalID: 1}, - }}, - Rule{ID: expectedIds[2], Name: "rule3", Color: "#ddd", Patterns: []Pattern{ - {Regex: "/pattern2/", Direction: DirectionBoth, internalID: 1}, - {Regex: "/pattern3/", Flags: RegexFlags{MultiLine: true}, internalID: 2}, - }}, - Rule{ID: expectedIds[3], Name: "rule4", Color: "#ccc", Patterns: []Pattern{ - {Regex: "/pattern3/", internalID: 3}, - }}, - } - ids, err := wrapper.Storage.Insert(Rules).Context(wrapper.Context).Many(rules) - require.NoError(t, err) - assert.ElementsMatch(t, expectedIds, ids) - - rulesManager, err := LoadRulesManager(wrapper.Storage, "FLAG{nope}") - require.NoError(t, err) - - rule, isPresent := rulesManager.GetRule(NewRowID()) - assert.Zero(t, rule) - assert.False(t, isPresent) - - for _, objRule := range rules { - expected := objRule.(Rule) - rule, isPresent := rulesManager.GetRule(expected.ID) - assert.True(t, isPresent) - assert.Equal(t, expected, rule) - } - - updated, err := rulesManager.UpdateRule(wrapper.Context, NewRowID(), Rule{}) - assert.False(t, updated) - assert.NoError(t, err) - - updated, err = rulesManager.UpdateRule(wrapper.Context, expectedIds[0], Rule{Name: "rule2", Color: "#fff"}) - assert.False(t, updated) - assert.Error(t, err) - - for _, objRule := range rules { - expected := objRule.(Rule) - expected.Name = expected.ID.Hex() - expected.Color = "#000" - updated, err := rulesManager.UpdateRule(wrapper.Context, expected.ID, expected) - assert.True(t, updated) - assert.NoError(t, err) - - rule, isPresent := rulesManager.GetRule(expected.ID) - assert.True(t, isPresent) - assert.Equal(t, expected, rule) - } - - wrapper.Destroy(t) -} - -func TestFillWithMatchedRules(t *testing.T) { - wrapper := NewTestStorageWrapper(t) - wrapper.AddCollection(Rules) - - rulesManager, err := LoadRulesManager(wrapper.Storage, "FLAG{test}") - require.NoError(t, err) - impl := rulesManager.(*rulesManagerImpl) - checkVersion(t, rulesManager, impl.rulesByName["flag_out"].ID) - checkVersion(t, rulesManager, impl.rulesByName["flag_in"].ID) - - emptyRule, err := rulesManager.AddRule(wrapper.Context, Rule{Name: "empty", Color: "#fff"}) - require.NoError(t, err) - checkVersion(t, rulesManager, emptyRule) - - conn := &Connection{} - rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{}, map[uint][]PatternSlice{}) - assert.ElementsMatch(t, []RowID{emptyRule}, conn.MatchedRules) - - filterRule, err := rulesManager.AddRule(wrapper.Context, Rule{ - Name: "filter", - Color: "#fff", - Filter: Filter{ - ServicePort: 80, - ClientAddress: "10.10.10.10", - ClientPort: 60000, - MinDuration: 2000, - MaxDuration: 4000, - MinBytes: 64, - MaxBytes: 64, - }, - }) - require.NoError(t, err) - checkVersion(t, rulesManager, filterRule) - conn = &Connection{ - SourceIP: "10.10.10.10", - SourcePort: 60000, - DestinationPort: 80, - ClientBytes: 32, - ServerBytes: 32, - StartedAt: time.Now(), - ClosedAt: time.Now().Add(3 * time.Second), - } - rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{}, map[uint][]PatternSlice{}) - assert.ElementsMatch(t, []RowID{emptyRule, filterRule}, conn.MatchedRules) - - patternRule, err := rulesManager.AddRule(wrapper.Context, Rule{ - Name: "pattern", - Color: "#fff", - Patterns: []Pattern{ - {Regex: "pattern1", Direction: DirectionToClient, MinOccurrences: 1}, - {Regex: "pattern2", Direction: DirectionToServer, MaxOccurrences: 2}, - {Regex: "pattern3", Direction: DirectionBoth, MinOccurrences: 2, MaxOccurrences: 2}, - }, - }) - require.NoError(t, err) - checkVersion(t, rulesManager, patternRule) - conn = &Connection{} - rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0}, {0, 0}}, 3: {{0, 0}}}, - map[uint][]PatternSlice{1: {{0, 0}}, 3: {{0, 0}}}) - assert.ElementsMatch(t, []RowID{emptyRule, patternRule}, conn.MatchedRules) - - rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0}, {0, 0}}}, - map[uint][]PatternSlice{1: {{0, 0}}, 3: {{0, 0}, {0, 0}}}) - assert.ElementsMatch(t, []RowID{emptyRule, patternRule}, conn.MatchedRules) - - rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0}, {0, 0}}, 3: {{0, 0}, {0, 0}}}, - map[uint][]PatternSlice{1: {{0, 0}}}) - assert.ElementsMatch(t, []RowID{emptyRule, patternRule}, conn.MatchedRules) - - rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0}, {0, 0}}, 3: {{0, 0}}}, - map[uint][]PatternSlice{3: {{0, 0}}}) - assert.ElementsMatch(t, []RowID{emptyRule}, conn.MatchedRules) - - rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0}, {0, 0}, {0, 0}}, 3: {{0, 0}}}, - map[uint][]PatternSlice{1: {{0, 0}}, 3: {{0, 0}}}) - assert.ElementsMatch(t, []RowID{emptyRule}, conn.MatchedRules) - - rulesManager.FillWithMatchedRules(conn, map[uint][]PatternSlice{2: {{0, 0}, {0, 0}}, 3: {{0, 0}}}, - map[uint][]PatternSlice{1: {{0, 0}}, 3: {{0, 0}, {0, 0}}}) - assert.ElementsMatch(t, []RowID{emptyRule}, conn.MatchedRules) - - wrapper.Destroy(t) -} - -func checkVersion(t *testing.T, rulesManager RulesManager, id RowID) { - timeout := time.Tick(1 * time.Second) - - select { - case database := <-rulesManager.DatabaseUpdateChannel(): - assert.Equal(t, id, database.version) - case <-timeout: - t.Fatal("timeout") - } -} diff --git a/storage_test.go b/storage_test.go deleted file mode 100644 index dd91e97..0000000 --- a/storage_test.go +++ /dev/null @@ -1,258 +0,0 @@ -/* - * This file is part of caronte (https://github.com/eciavatta/caronte). - * Copyright (c) 2020 Emiliano Ciavatta. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, version 3. - * - * This program is distributed in the hope that it will be useful, but - * WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - * General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ - -package main - -import ( - "github.com/stretchr/testify/assert" - "go.mongodb.org/mongo-driver/bson/primitive" - "testing" - "time" -) - -type a struct { - ID primitive.ObjectID `bson:"_id,omitempty"` - A string `bson:"a,omitempty"` - B int `bson:"b,omitempty"` - C time.Time `bson:"c,omitempty"` - D map[string]b `bson:"d"` - E []b `bson:"e,omitempty"` -} - -type b struct { - A string `bson:"a,omitempty"` - B int `bson:"b,omitempty"` -} - -func TestOperationOnInvalidCollection(t *testing.T) { - wrapper := NewTestStorageWrapper(t) - - simpleDoc := UnorderedDocument{"key": "a", "value": 0} - insertOp := wrapper.Storage.Insert("invalid_collection").Context(wrapper.Context) - insertedID, err := insertOp.One(simpleDoc) - assert.Nil(t, insertedID) - assert.Error(t, err) - - insertedIDs, err := insertOp.Many([]interface{}{simpleDoc}) - assert.Nil(t, insertedIDs) - assert.Error(t, err) - - updateOp := wrapper.Storage.Update("invalid_collection").Context(wrapper.Context) - isUpdated, err := updateOp.One(simpleDoc) - assert.False(t, isUpdated) - assert.Error(t, err) - - updated, err := updateOp.Many(simpleDoc) - assert.Zero(t, updated) - assert.Error(t, err) - - findOp := wrapper.Storage.Find("invalid_collection").Context(wrapper.Context) - var result interface{} - err = findOp.First(&result) - assert.Nil(t, result) - assert.Error(t, err) - - var results interface{} - err = findOp.All(&result) - assert.Nil(t, results) - assert.Error(t, err) - - wrapper.Destroy(t) -} - -func TestSimpleInsertAndFind(t *testing.T) { - wrapper := NewTestStorageWrapper(t) - collectionName := "simple_insert_find" - wrapper.AddCollection(collectionName) - - insertOp := wrapper.Storage.Insert(collectionName).Context(wrapper.Context) - simpleDocA := UnorderedDocument{"key": "a"} - idA, err := insertOp.One(simpleDocA) - assert.Len(t, idA, 12) - assert.Nil(t, err) - - simpleDocB := UnorderedDocument{"_id": "idb", "key": "b"} - idB, err := insertOp.One(simpleDocB) - assert.Equal(t, "idb", idB) - assert.Nil(t, err) - - var result UnorderedDocument - findOp := wrapper.Storage.Find(collectionName).Context(wrapper.Context) - err = findOp.Filter(OrderedDocument{{"key", "a"}}).First(&result) - assert.Nil(t, err) - assert.Equal(t, idA, result["_id"]) - assert.Equal(t, simpleDocA["key"], result["key"]) - - err = findOp.Filter(OrderedDocument{{"_id", idB}}).First(&result) - assert.Nil(t, err) - assert.Equal(t, idB, result["_id"]) - assert.Equal(t, simpleDocB["key"], result["key"]) - - wrapper.Destroy(t) -} - -func TestSimpleInsertManyAndFindMany(t *testing.T) { - wrapper := NewTestStorageWrapper(t) - collectionName := "simple_insert_many_find_many" - wrapper.AddCollection(collectionName) - - insertOp := wrapper.Storage.Insert(collectionName).Context(wrapper.Context) - simpleDocs := []interface{}{ - UnorderedDocument{"key": "a"}, - UnorderedDocument{"_id": "idb", "key": "b"}, - UnorderedDocument{"key": "c"}, - } - ids, err := insertOp.Many(simpleDocs) - assert.Nil(t, err) - assert.Len(t, ids, 3) - assert.Equal(t, "idb", ids[1]) - - var results []UnorderedDocument - findOp := wrapper.Storage.Find(collectionName).Context(wrapper.Context) - err = findOp.Sort("key", false).All(&results) // test sort ascending - assert.Nil(t, err) - assert.Len(t, results, 3) - assert.Equal(t, "c", results[0]["key"]) - assert.Equal(t, "b", results[1]["key"]) - assert.Equal(t, "a", results[2]["key"]) - - err = findOp.Sort("key", true).All(&results) // test sort descending - assert.Nil(t, err) - assert.Len(t, results, 3) - assert.Equal(t, "c", results[2]["key"]) - assert.Equal(t, "b", results[1]["key"]) - assert.Equal(t, "a", results[0]["key"]) - - err = findOp.Filter(OrderedDocument{{"key", OrderedDocument{{"$gte", "b"}}}}). - Sort("key", true).All(&results) // test filter - assert.Nil(t, err) - assert.Len(t, results, 2) - assert.Equal(t, "b", results[0]["key"]) - assert.Equal(t, "c", results[1]["key"]) - - wrapper.Destroy(t) -} - -func TestSimpleUpdateOneUpdateMany(t *testing.T) { - wrapper := NewTestStorageWrapper(t) - collectionName := "simple_update_one_update_many" - wrapper.AddCollection(collectionName) - - insertOp := wrapper.Storage.Insert(collectionName).Context(wrapper.Context) - simpleDocs := []interface{}{ - UnorderedDocument{"_id": "ida", "key": "a"}, - UnorderedDocument{"key": "b"}, - UnorderedDocument{"key": "c"}, - } - _, err := insertOp.Many(simpleDocs) - assert.Nil(t, err) - - updateOp := wrapper.Storage.Update(collectionName).Context(wrapper.Context) - isUpdated, err := updateOp.Filter(OrderedDocument{{"_id", "ida"}}). - One(OrderedDocument{{"key", "aa"}}) - assert.Nil(t, err) - assert.True(t, isUpdated) - - updated, err := updateOp.Filter(OrderedDocument{{"key", OrderedDocument{{"$gte", "b"}}}}). - Many(OrderedDocument{{"key", "bb"}}) - assert.Nil(t, err) - assert.Equal(t, int64(2), updated) - - var upsertID interface{} - isUpdated, err = updateOp.Upsert(&upsertID).Filter(OrderedDocument{{"key", "d"}}). - One(OrderedDocument{{"key", "d"}}) - assert.Nil(t, err) - assert.False(t, isUpdated) - assert.NotNil(t, upsertID) - - var results []UnorderedDocument - findOp := wrapper.Storage.Find(collectionName).Context(wrapper.Context) - err = findOp.Sort("key", true).All(&results) // test sort ascending - assert.Nil(t, err) - assert.Len(t, results, 4) - assert.Equal(t, "aa", results[0]["key"]) - assert.Equal(t, "bb", results[1]["key"]) - assert.Equal(t, "bb", results[2]["key"]) - assert.Equal(t, "d", results[3]["key"]) - - wrapper.Destroy(t) -} - -func TestComplexInsertManyFindMany(t *testing.T) { - wrapper := NewTestStorageWrapper(t) - collectionName := "complex_insert_many_find_many" - wrapper.AddCollection(collectionName) - - testTime := time.Now() - oid1, err := primitive.ObjectIDFromHex("ffffffffffffffffffffffff") - assert.Nil(t, err) - - docs := []interface{}{ - a{ - A: "test0", - B: 0, - C: testTime, - D: map[string]b{ - "first": {A: "0", B: 0}, - "second": {A: "1", B: 1}, - }, - E: []b{ - {A: "0", B: 0}, {A: "1", B: 0}, - }, - }, - a{ - ID: oid1, - A: "test1", - B: 1, - C: testTime, - D: map[string]b{}, - E: []b{}, - }, - a{}, - } - - ids, err := wrapper.Storage.Insert(collectionName).Context(wrapper.Context).Many(docs) - assert.Nil(t, err) - assert.Len(t, ids, 3) - assert.Equal(t, ids[1], oid1) - - var results []a - err = wrapper.Storage.Find(collectionName).Context(wrapper.Context).All(&results) - assert.Nil(t, err) - assert.Len(t, results, 3) - doc0, doc1, doc2 := docs[0].(a), docs[1].(a), docs[2].(a) - assert.Equal(t, ids[0], results[0].ID) - assert.Equal(t, doc1.ID, results[1].ID) - assert.Equal(t, ids[2], results[2].ID) - assert.Equal(t, doc0.A, results[0].A) - assert.Equal(t, doc1.A, results[1].A) - assert.Equal(t, doc2.A, results[2].A) - assert.Equal(t, doc0.B, results[0].B) - assert.Equal(t, doc1.B, results[1].B) - assert.Equal(t, doc2.B, results[2].B) - assert.Equal(t, doc0.C.Unix(), results[0].C.Unix()) - assert.Equal(t, doc1.C.Unix(), results[1].C.Unix()) - assert.Equal(t, doc2.C.Unix(), results[2].C.Unix()) - assert.Equal(t, doc0.D, results[0].D) - assert.Equal(t, doc1.D, results[1].D) - assert.Equal(t, doc2.D, results[2].D) - assert.Equal(t, doc0.E, results[0].E) - assert.Nil(t, results[1].E) - assert.Nil(t, results[2].E) - - wrapper.Destroy(t) -} diff --git a/stream_handler_test.go b/stream_handler_test.go deleted file mode 100644 index b185483..0000000 --- a/stream_handler_test.go +++ /dev/null @@ -1,359 +0,0 @@ -/* - * This file is part of caronte (https://github.com/eciavatta/caronte). - * Copyright (c) 2020 Emiliano Ciavatta. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, version 3. - * - * This program is distributed in the hope that it will be useful, but - * WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - * General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ - -package main - -import ( - "context" - "crypto/rand" - "net" - "testing" - "time" - - "github.com/flier/gohs/hyperscan" - "github.com/google/gopacket/layers" - "github.com/google/gopacket/tcpassembly" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -const testSrcIP = "10.10.10.100" -const testDstIP = "10.10.10.1" -const srcPort = 44444 -const dstPort = 8080 - -func TestReassemblingEmptyStream(t *testing.T) { - wrapper := NewTestStorageWrapper(t) - wrapper.AddCollection(ConnectionStreams) - patterns, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/nope/", 0)) - require.NoError(t, err) - scratch, err := hyperscan.NewScratch(patterns) - require.NoError(t, err) - streamHandler := createTestStreamHandler(wrapper, patterns, scratch) - - streamHandler.Reassembled([]tcpassembly.Reassembly{{ - Bytes: []byte{}, - Skip: 0, - Start: true, - End: true, - }}) - assert.Len(t, streamHandler.indexes, 0, "indexes") - assert.Len(t, streamHandler.timestamps, 0, "timestamps") - assert.Len(t, streamHandler.lossBlocks, 0) - assert.Zero(t, streamHandler.currentIndex) - assert.Zero(t, streamHandler.firstPacketSeen) - assert.Zero(t, streamHandler.lastPacketSeen) - assert.Len(t, streamHandler.documentsIDs, 0) - assert.Zero(t, streamHandler.streamLength) - assert.Len(t, streamHandler.patternMatches, 0) - - completed := false - streamHandler.connection.(*testConnectionHandler).onComplete = func(handler *StreamHandler) { - completed = true - } - streamHandler.ReassemblyComplete() - assert.Equal(t, true, completed) - - err = scratch.Free() - require.NoError(t, err, "free scratch") - err = patterns.Close() - require.NoError(t, err, "close stream database") - wrapper.Destroy(t) -} - -func TestReassemblingSingleDocument(t *testing.T) { - wrapper := NewTestStorageWrapper(t) - wrapper.AddCollection(ConnectionStreams) - patterns, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/impossible_to_match/", 0)) - require.NoError(t, err) - scratch, err := hyperscan.NewScratch(patterns) - require.NoError(t, err) - streamHandler := createTestStreamHandler(wrapper, patterns, scratch) - - payloadLen := 256 - firstTime := time.Unix(0, 0) - middleTime := time.Unix(10, 0) - lastTime := time.Unix(20, 0) - data := make([]byte, MaxDocumentSize) - rand.Read(data) - reassembles := make([]tcpassembly.Reassembly, MaxDocumentSize/payloadLen) - indexes := make([]int, MaxDocumentSize/payloadLen) - timestamps := make([]time.Time, MaxDocumentSize/payloadLen) - lossBlocks := make([]bool, MaxDocumentSize/payloadLen) - for i := 0; i < len(reassembles); i++ { - var seen time.Time - if i == 0 { - seen = firstTime - } else if i == len(reassembles)-1 { - seen = lastTime - } else { - seen = middleTime - } - - reassembles[i] = tcpassembly.Reassembly{ - Bytes: data[i*payloadLen : (i+1)*payloadLen], - Skip: 0, - Start: i == 0, - End: i == len(reassembles)-1, - Seen: seen, - } - indexes[i] = i * payloadLen - timestamps[i] = seen - } - - var results []ConnectionStream - - streamHandler.Reassembled(reassembles) - err = wrapper.Storage.Find(ConnectionStreams).Context(wrapper.Context).All(&results) - require.NoError(t, err) - assert.Len(t, results, 0) - - completed := false - streamHandler.connection.(*testConnectionHandler).onComplete = func(handler *StreamHandler) { - completed = true - } - streamHandler.ReassemblyComplete() - - err = wrapper.Storage.Find(ConnectionStreams).Context(wrapper.Context).All(&results) - require.NoError(t, err) - assert.Len(t, results, 1) - assert.Equal(t, firstTime.Unix(), results[0].ID.Timestamp().Unix()) - assert.Zero(t, results[0].ConnectionID) - assert.Equal(t, 0, results[0].DocumentIndex) - assert.Equal(t, data, results[0].Payload) - assert.Equal(t, indexes, results[0].BlocksIndexes) - assert.Len(t, results[0].BlocksTimestamps, len(timestamps)) // should be compared one by one - assert.Equal(t, lossBlocks, results[0].BlocksLoss) - assert.Len(t, results[0].PatternMatches, 0) - - assert.Equal(t, len(data), streamHandler.currentIndex) - assert.Equal(t, firstTime, streamHandler.firstPacketSeen) - assert.Equal(t, lastTime, streamHandler.lastPacketSeen) - assert.Len(t, streamHandler.documentsIDs, 1) - assert.Equal(t, len(data), streamHandler.streamLength) - assert.Len(t, streamHandler.patternMatches, 0) - - assert.Equal(t, true, completed, "completed") - - err = scratch.Free() - require.NoError(t, err, "free scratch") - err = patterns.Close() - require.NoError(t, err, "close stream database") - wrapper.Destroy(t) -} - -func TestReassemblingMultipleDocuments(t *testing.T) { - wrapper := NewTestStorageWrapper(t) - wrapper.AddCollection(ConnectionStreams) - patterns, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/impossible_to_match/", 0)) - require.NoError(t, err) - scratch, err := hyperscan.NewScratch(patterns) - require.NoError(t, err) - streamHandler := createTestStreamHandler(wrapper, patterns, scratch) - - payloadLen := 256 - firstTime := time.Unix(0, 0) - middleTime := time.Unix(10, 0) - lastTime := time.Unix(20, 0) - dataSize := MaxDocumentSize * 2 - data := make([]byte, dataSize) - rand.Read(data) - reassembles := make([]tcpassembly.Reassembly, dataSize/payloadLen) - indexes := make([]int, dataSize/payloadLen) - timestamps := make([]time.Time, dataSize/payloadLen) - lossBlocks := make([]bool, dataSize/payloadLen) - for i := 0; i < len(reassembles); i++ { - var seen time.Time - if i == 0 { - seen = firstTime - } else if i == len(reassembles)-1 { - seen = lastTime - } else { - seen = middleTime - } - - reassembles[i] = tcpassembly.Reassembly{ - Bytes: data[i*payloadLen : (i+1)*payloadLen], - Skip: 0, - Start: i == 0, - End: i == len(reassembles)-1, - Seen: seen, - } - indexes[i] = i * payloadLen % MaxDocumentSize - timestamps[i] = seen - } - - streamHandler.Reassembled(reassembles) - - var results []ConnectionStream - err = wrapper.Storage.Find(ConnectionStreams).Context(wrapper.Context).All(&results) - require.NoError(t, err) - assert.Len(t, results, 1) - - completed := false - streamHandler.connection.(*testConnectionHandler).onComplete = func(handler *StreamHandler) { - completed = true - } - streamHandler.ReassemblyComplete() - - err = wrapper.Storage.Find(ConnectionStreams).Context(wrapper.Context).All(&results) - require.NoError(t, err) - assert.Len(t, results, 2) - for i := 0; i < 2; i++ { - blockLen := MaxDocumentSize / payloadLen - assert.Equal(t, firstTime.Unix(), results[i].ID.Timestamp().Unix()) - assert.Zero(t, results[i].ConnectionID) - assert.Equal(t, i, results[i].DocumentIndex) - assert.Equal(t, data[MaxDocumentSize*i:MaxDocumentSize*(i+1)], results[i].Payload) - assert.Equal(t, indexes[blockLen*i:blockLen*(i+1)], results[i].BlocksIndexes) - assert.Len(t, results[i].BlocksTimestamps, len(timestamps[blockLen*i:blockLen*(i+1)])) // should be compared one by one - assert.Equal(t, lossBlocks[blockLen*i:blockLen*(i+1)], results[i].BlocksLoss) - assert.Len(t, results[i].PatternMatches, 0) - } - - assert.Equal(t, MaxDocumentSize, streamHandler.currentIndex) - assert.Equal(t, firstTime, streamHandler.firstPacketSeen) - assert.Equal(t, lastTime, streamHandler.lastPacketSeen) - assert.Len(t, streamHandler.documentsIDs, 2) - assert.Equal(t, len(data), streamHandler.streamLength) - assert.Len(t, streamHandler.patternMatches, 0) - - assert.Equal(t, true, completed, "completed") - - err = scratch.Free() - require.NoError(t, err, "free scratch") - err = patterns.Close() - require.NoError(t, err, "close stream database") - wrapper.Destroy(t) -} - -func TestReassemblingPatternMatching(t *testing.T) { - wrapper := NewTestStorageWrapper(t) - wrapper.AddCollection(ConnectionStreams) - a, err := hyperscan.ParsePattern("/a{8}/i") - require.NoError(t, err) - a.Id = 0 - a.Flags |= hyperscan.SomLeftMost - b, err := hyperscan.ParsePattern("/b[c]+b/i") - require.NoError(t, err) - b.Id = 1 - b.Flags |= hyperscan.SomLeftMost - d, err := hyperscan.ParsePattern("/[d]+e[d]+/i") - require.NoError(t, err) - d.Id = 2 - d.Flags |= hyperscan.SomLeftMost - - payload := "aaaaaaaa0aaaaaaaaaa0bbbcccbbb0dddeddddedddd" - expected := map[uint][]PatternSlice{ - 0: {{0, 8}, {9, 17}, {10, 18}, {11, 19}}, - 1: {{22, 27}}, - 2: {{30, 38}, {34, 43}}, - } - - patterns, err := hyperscan.NewStreamDatabase(a, b, d) - require.NoError(t, err) - scratch, err := hyperscan.NewScratch(patterns) - require.NoError(t, err) - streamHandler := createTestStreamHandler(wrapper, patterns, scratch) - - seen := time.Unix(0, 0) - streamHandler.Reassembled([]tcpassembly.Reassembly{{ - Bytes: []byte(payload), - Skip: 0, - Start: true, - End: true, - Seen: seen, - }}) - - var results []ConnectionStream - err = wrapper.Storage.Find(ConnectionStreams).Context(wrapper.Context).All(&results) - require.NoError(t, err) - assert.Len(t, results, 0) - - completed := false - streamHandler.connection.(*testConnectionHandler).onComplete = func(handler *StreamHandler) { - completed = true - } - streamHandler.ReassemblyComplete() - - err = wrapper.Storage.Find(ConnectionStreams).Context(wrapper.Context).All(&results) - require.NoError(t, err) - assert.Len(t, results, 1) - assert.Equal(t, seen.Unix(), results[0].ID.Timestamp().Unix()) - assert.Zero(t, results[0].ConnectionID) - assert.Equal(t, 0, results[0].DocumentIndex) - assert.Equal(t, []byte(payload), results[0].Payload) - assert.Equal(t, []int{0}, results[0].BlocksIndexes) - assert.Len(t, results[0].BlocksTimestamps, 1) // should be compared one by one - assert.Equal(t, []bool{false}, results[0].BlocksLoss) - assert.Equal(t, expected, results[0].PatternMatches) - - assert.Equal(t, len(payload), streamHandler.currentIndex) - assert.Equal(t, seen, streamHandler.firstPacketSeen) - assert.Equal(t, seen, streamHandler.lastPacketSeen) - assert.Len(t, streamHandler.documentsIDs, 1) - assert.Equal(t, len(payload), streamHandler.streamLength) - - assert.Equal(t, true, completed, "completed") - - err = scratch.Free() - require.NoError(t, err, "free scratch") - err = patterns.Close() - require.NoError(t, err, "close stream database") - wrapper.Destroy(t) -} - -func createTestStreamHandler(wrapper *TestStorageWrapper, patterns hyperscan.StreamDatabase, scratch *hyperscan.Scratch) StreamHandler { - testConnectionHandler := &testConnectionHandler{ - wrapper: wrapper, - patterns: patterns, - } - - srcIP := layers.NewIPEndpoint(net.ParseIP(testSrcIP)) - dstIP := layers.NewIPEndpoint(net.ParseIP(testDstIP)) - srcPort := layers.NewTCPPortEndpoint(srcPort) - dstPort := layers.NewTCPPortEndpoint(dstPort) - - scanner := Scanner{scratch: scratch, version: ZeroRowID} - return NewStreamHandler(testConnectionHandler, StreamFlow{srcIP, dstIP, srcPort, dstPort}, scanner, true) // TODO: test isClient -} - -type testConnectionHandler struct { - wrapper *TestStorageWrapper - patterns hyperscan.StreamDatabase - onComplete func(*StreamHandler) -} - -func (tch *testConnectionHandler) Storage() Storage { - return tch.wrapper.Storage -} - -func (tch *testConnectionHandler) Context() context.Context { - return tch.wrapper.Context -} - -func (tch *testConnectionHandler) PatternsDatabase() hyperscan.StreamDatabase { - return tch.patterns -} - -func (tch *testConnectionHandler) PatternsDatabaseSize() int { - return 8 -} - -func (tch *testConnectionHandler) Complete(handler *StreamHandler) { - tch.onComplete(handler) -} -- cgit v1.2.3-70-g09d2