From db8ff43c5e1595c02e2ba67c3c78f239723f95bd Mon Sep 17 00:00:00 2001 From: Emiliano Ciavatta Date: Fri, 17 Jul 2020 11:12:09 +0200 Subject: Added support for cidr addresses when checking server ip --- application_context.go | 15 +++++++-------- application_context_test.go | 6 +++--- application_router_test.go | 4 ++-- connection_handler.go | 18 ++++++++++-------- connection_handler_test.go | 6 +++--- pcap_importer.go | 16 ++++++++-------- pcap_importer_test.go | 7 ++----- utils.go | 24 ++++++++++++++++++++++++ 8 files changed, 59 insertions(+), 37 deletions(-) diff --git a/application_context.go b/application_context.go index 6960c7d..e4be74d 100644 --- a/application_context.go +++ b/application_context.go @@ -3,13 +3,12 @@ package main import ( "github.com/gin-gonic/gin" log "github.com/sirupsen/logrus" - "net" ) type Config struct { - ServerIP string `json:"server_ip" binding:"required,ip" bson:"server_ip"` - FlagRegex string `json:"flag_regex" binding:"required,min=8" bson:"flag_regex"` - AuthRequired bool `json:"auth_required" bson:"auth_required"` + ServerAddress string `json:"server_address" binding:"required,ip|cidr" bson:"server_address"` + FlagRegex string `json:"flag_regex" binding:"required,min=8" bson:"flag_regex"` + AuthRequired bool `json:"auth_required" bson:"auth_required"` } type ApplicationContext struct { @@ -77,11 +76,11 @@ func (sm *ApplicationContext) configure() { if sm.IsConfigured { return } - if sm.Config.ServerIP == "" || sm.Config.FlagRegex == "" { + if sm.Config.ServerAddress == "" || sm.Config.FlagRegex == "" { return } - serverIP := net.ParseIP(sm.Config.ServerIP) - if serverIP == nil { + serverNet := ParseIPNet(sm.Config.ServerAddress) + if serverNet == nil { return } @@ -90,7 +89,7 @@ func (sm *ApplicationContext) configure() { log.WithError(err).Panic("failed to create a RulesManager") } sm.RulesManager = rulesManager - sm.PcapImporter = NewPcapImporter(sm.Storage, serverIP, sm.RulesManager) + sm.PcapImporter = NewPcapImporter(sm.Storage, *serverNet, sm.RulesManager) sm.ServicesController = NewServicesController(sm.Storage) sm.ConnectionsController = NewConnectionsController(sm.Storage, sm.ServicesController) sm.ConnectionStreamsController = NewConnectionStreamsController(sm.Storage) diff --git a/application_context_test.go b/application_context_test.go index 2a94cf6..eed0fd6 100644 --- a/application_context_test.go +++ b/application_context_test.go @@ -19,9 +19,9 @@ func TestCreateApplicationContext(t *testing.T) { assert.Nil(t, appContext.RulesManager) config := Config{ - ServerIP: "10.10.10.10", - FlagRegex: "FLAG{test}", - AuthRequired: true, + ServerAddress: "10.10.10.10", + FlagRegex: "FLAG{test}", + AuthRequired: true, } accounts := gin.Accounts{ "username": "password", diff --git a/application_router_test.go b/application_router_test.go index c8e4474..4225ab9 100644 --- a/application_router_test.go +++ b/application_router_test.go @@ -19,7 +19,7 @@ func TestSetupApplication(t *testing.T) { settings := make(map[string]interface{}) assert.Equal(t, http.StatusServiceUnavailable, toolkit.MakeRequest("GET", "/api/rules", nil).Code) assert.Equal(t, http.StatusBadRequest, toolkit.MakeRequest("POST", "/setup", settings).Code) - settings["config"] = Config{ServerIP: "1.2.3.4", FlagRegex: "FLAG{test}", AuthRequired: true} + settings["config"] = Config{ServerAddress: "1.2.3.4", FlagRegex: "FLAG{test}", AuthRequired: true} assert.Equal(t, http.StatusBadRequest, toolkit.MakeRequest("POST", "/setup", settings).Code) settings["accounts"] = gin.Accounts{"username": "password"} assert.Equal(t, http.StatusAccepted, toolkit.MakeRequest("POST", "/setup", settings).Code) @@ -162,7 +162,7 @@ func NewRouterTestToolkit(t *testing.T, withSetup bool) *RouterTestToolkit { if withSetup { settings := gin.H{ - "config": Config{ServerIP: "1.2.3.4", FlagRegex: "FLAG{test}", AuthRequired: false}, + "config": Config{ServerAddress: "1.2.3.4", FlagRegex: "FLAG{test}", AuthRequired: false}, "accounts": gin.Accounts{}, } toolkit.MakeRequest("POST", "/setup", settings) diff --git a/connection_handler.go b/connection_handler.go index ddf5e55..ffe4fac 100644 --- a/connection_handler.go +++ b/connection_handler.go @@ -7,6 +7,7 @@ import ( "github.com/google/gopacket/tcpassembly" log "github.com/sirupsen/logrus" "hash/fnv" + "net" "sync" "time" ) @@ -16,7 +17,7 @@ const initialScannersCapacity = 1024 type BiDirectionalStreamFactory struct { storage Storage - serverIP gopacket.Endpoint + serverNet net.IPNet connections map[StreamFlow]ConnectionHandler mConnections sync.Mutex rulesManager RulesManager @@ -46,12 +47,12 @@ type connectionHandlerImpl struct { otherStream *StreamHandler } -func NewBiDirectionalStreamFactory(storage Storage, serverIP gopacket.Endpoint, +func NewBiDirectionalStreamFactory(storage Storage, serverNet net.IPNet, rulesManager RulesManager) *BiDirectionalStreamFactory { factory := &BiDirectionalStreamFactory{ storage: storage, - serverIP: serverIP, + serverNet: serverNet, connections: make(map[StreamFlow]ConnectionHandler, initialConnectionsCapacity), mConnections: sync.Mutex{}, rulesManager: rulesManager, @@ -128,17 +129,18 @@ func (factory *BiDirectionalStreamFactory) releaseScanner(scanner Scanner) { factory.scanners = append(factory.scanners, scanner) } -func (factory *BiDirectionalStreamFactory) New(net, transport gopacket.Flow) tcpassembly.Stream { - flow := StreamFlow{net.Src(), net.Dst(), transport.Src(), transport.Dst()} - invertedFlow := StreamFlow{net.Dst(), net.Src(), transport.Dst(), transport.Src()} +func (factory *BiDirectionalStreamFactory) New(netFlow, transportFlow gopacket.Flow) tcpassembly.Stream { + flow := StreamFlow{netFlow.Src(), netFlow.Dst(), transportFlow.Src(), transportFlow.Dst()} + invertedFlow := StreamFlow{netFlow.Dst(), netFlow.Src(), transportFlow.Dst(), transportFlow.Src()} factory.mConnections.Lock() connection, isPresent := factory.connections[invertedFlow] + isServer := factory.serverNet.Contains(netFlow.Src().Raw()) if isPresent { delete(factory.connections, invertedFlow) } else { var connectionFlow StreamFlow - if net.Src() == factory.serverIP { + if isServer { connectionFlow = invertedFlow } else { connectionFlow = flow @@ -152,7 +154,7 @@ func (factory *BiDirectionalStreamFactory) New(net, transport gopacket.Flow) tcp } factory.mConnections.Unlock() - streamHandler := NewStreamHandler(connection, flow, factory.takeScanner(), net.Src() != factory.serverIP) + streamHandler := NewStreamHandler(connection, flow, factory.takeScanner(), !isServer) return &streamHandler } diff --git a/connection_handler_test.go b/connection_handler_test.go index 7b00f6e..0bee0ac 100644 --- a/connection_handler_test.go +++ b/connection_handler_test.go @@ -17,7 +17,7 @@ import ( func TestTakeReleaseScanners(t *testing.T) { wrapper := NewTestStorageWrapper(t) - serverIP := layers.NewIPEndpoint(net.ParseIP(testDstIP)) + serverNet := ParseIPNet(testDstIP) ruleManager := TestRulesManager{ databaseUpdated: make(chan RulesDatabase), } @@ -25,7 +25,7 @@ func TestTakeReleaseScanners(t *testing.T) { database, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/nope/", 0)) require.NoError(t, err) - factory := NewBiDirectionalStreamFactory(wrapper.Storage, serverIP, &ruleManager) + factory := NewBiDirectionalStreamFactory(wrapper.Storage, *serverNet, &ruleManager) version := NewRowID() ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version} time.Sleep(10 * time.Millisecond) @@ -88,7 +88,7 @@ func TestConnectionFactory(t *testing.T) { database, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/nope/", 0)) require.NoError(t, err) - factory := NewBiDirectionalStreamFactory(wrapper.Storage, serverIP, &ruleManager) + factory := NewBiDirectionalStreamFactory(wrapper.Storage, *ParseIPNet(testDstIP), &ruleManager) version := NewRowID() ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version} time.Sleep(10 * time.Millisecond) diff --git a/pcap_importer.go b/pcap_importer.go index 9d3f5bc..cd6fdfa 100644 --- a/pcap_importer.go +++ b/pcap_importer.go @@ -29,7 +29,7 @@ type PcapImporter struct { sessions map[string]ImportingSession mAssemblers sync.Mutex mSessions sync.Mutex - serverIP gopacket.Endpoint + serverNet net.IPNet } type ImportingSession struct { @@ -47,9 +47,8 @@ type ImportingSession struct { type flowCount [2]int -func NewPcapImporter(storage Storage, serverIP net.IP, rulesManager RulesManager) *PcapImporter { - serverEndpoint := layers.NewIPEndpoint(serverIP) - streamPool := tcpassembly.NewStreamPool(NewBiDirectionalStreamFactory(storage, serverEndpoint, rulesManager)) +func NewPcapImporter(storage Storage, serverNet net.IPNet, rulesManager RulesManager) *PcapImporter { + streamPool := tcpassembly.NewStreamPool(NewBiDirectionalStreamFactory(storage, serverNet, rulesManager)) var result []ImportingSession if err := storage.Find(ImportingSessions).All(&result); err != nil { @@ -67,7 +66,7 @@ func NewPcapImporter(storage Storage, serverIP net.IP, rulesManager RulesManager sessions: sessions, mAssemblers: sync.Mutex{}, mSessions: sync.Mutex{}, - serverIP: serverEndpoint, + serverNet: serverNet, } } @@ -198,8 +197,9 @@ func (pi *PcapImporter) parsePcap(session ImportingSession, fileName string, ctx tcp := packet.TransportLayer().(*layers.TCP) var servicePort uint16 var index int - isDstServer := packet.NetworkLayer().NetworkFlow().Dst() == pi.serverIP - isSrcServer := packet.NetworkLayer().NetworkFlow().Src() == pi.serverIP + + isDstServer := pi.serverNet.Contains(packet.NetworkLayer().NetworkFlow().Dst().Raw()) + isSrcServer := pi.serverNet.Contains(packet.NetworkLayer().NetworkFlow().Src().Raw()) if isDstServer && !isSrcServer { servicePort = uint16(tcp.DstPort) index = 0 @@ -208,7 +208,7 @@ func (pi *PcapImporter) parsePcap(session ImportingSession, fileName string, ctx index = 1 } else { session.InvalidPackets++ - // continue // workaround to process packets when services have multiple ips + continue } fCount, isPresent := session.PacketsPerService[servicePort] if !isPresent { diff --git a/pcap_importer_test.go b/pcap_importer_test.go index 6f9d4a5..74dd2cc 100644 --- a/pcap_importer_test.go +++ b/pcap_importer_test.go @@ -4,12 +4,10 @@ import ( "bufio" "fmt" "github.com/google/gopacket" - "github.com/google/gopacket/layers" "github.com/google/gopacket/tcpassembly" "github.com/google/gopacket/tcpassembly/tcpreader" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "net" "os" "sync" "testing" @@ -99,10 +97,9 @@ func TestImportNoTcpPackets(t *testing.T) { wrapper.Destroy(t) } -func newTestPcapImporter(wrapper *TestStorageWrapper, serverIP string) *PcapImporter { +func newTestPcapImporter(wrapper *TestStorageWrapper, serverAddress string) *PcapImporter { wrapper.AddCollection(ImportingSessions) - serverEndpoint := layers.NewIPEndpoint(net.ParseIP(serverIP)) streamPool := tcpassembly.NewStreamPool(&testStreamFactory{}) return &PcapImporter{ @@ -112,7 +109,7 @@ func newTestPcapImporter(wrapper *TestStorageWrapper, serverIP string) *PcapImpo sessions: make(map[string]ImportingSession), mAssemblers: sync.Mutex{}, mSessions: sync.Mutex{}, - serverIP: serverEndpoint, + serverNet: *ParseIPNet(serverAddress), } } diff --git a/utils.go b/utils.go index a015b75..a14fdca 100644 --- a/utils.go +++ b/utils.go @@ -10,6 +10,7 @@ import ( log "github.com/sirupsen/logrus" "go.mongodb.org/mongo-driver/bson/primitive" "io" + "net" "os" "time" ) @@ -127,3 +128,26 @@ func CopyFile(dst, src string) error { } return out.Close() } + +func ParseIPNet(address string) *net.IPNet { + _, network, err := net.ParseCIDR(address) + if err != nil { + ip := net.ParseIP(address) + if ip == nil { + return nil + } + + size := 0 + if ip.To4() != nil { + size = 32 + } else { + size = 128 + } + network = &net.IPNet{ + IP: ip, + Mask: net.CIDRMask(size, size), + } + } + + return network +} -- cgit v1.2.3-70-g09d2