aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--application_context.go15
-rw-r--r--application_context_test.go6
-rw-r--r--application_router_test.go4
-rw-r--r--connection_handler.go18
-rw-r--r--connection_handler_test.go6
-rw-r--r--pcap_importer.go16
-rw-r--r--pcap_importer_test.go7
-rw-r--r--utils.go24
8 files changed, 59 insertions, 37 deletions
diff --git a/application_context.go b/application_context.go
index 6960c7d..e4be74d 100644
--- a/application_context.go
+++ b/application_context.go
@@ -3,13 +3,12 @@ package main
import (
"github.com/gin-gonic/gin"
log "github.com/sirupsen/logrus"
- "net"
)
type Config struct {
- ServerIP string `json:"server_ip" binding:"required,ip" bson:"server_ip"`
- FlagRegex string `json:"flag_regex" binding:"required,min=8" bson:"flag_regex"`
- AuthRequired bool `json:"auth_required" bson:"auth_required"`
+ ServerAddress string `json:"server_address" binding:"required,ip|cidr" bson:"server_address"`
+ FlagRegex string `json:"flag_regex" binding:"required,min=8" bson:"flag_regex"`
+ AuthRequired bool `json:"auth_required" bson:"auth_required"`
}
type ApplicationContext struct {
@@ -77,11 +76,11 @@ func (sm *ApplicationContext) configure() {
if sm.IsConfigured {
return
}
- if sm.Config.ServerIP == "" || sm.Config.FlagRegex == "" {
+ if sm.Config.ServerAddress == "" || sm.Config.FlagRegex == "" {
return
}
- serverIP := net.ParseIP(sm.Config.ServerIP)
- if serverIP == nil {
+ serverNet := ParseIPNet(sm.Config.ServerAddress)
+ if serverNet == nil {
return
}
@@ -90,7 +89,7 @@ func (sm *ApplicationContext) configure() {
log.WithError(err).Panic("failed to create a RulesManager")
}
sm.RulesManager = rulesManager
- sm.PcapImporter = NewPcapImporter(sm.Storage, serverIP, sm.RulesManager)
+ sm.PcapImporter = NewPcapImporter(sm.Storage, *serverNet, sm.RulesManager)
sm.ServicesController = NewServicesController(sm.Storage)
sm.ConnectionsController = NewConnectionsController(sm.Storage, sm.ServicesController)
sm.ConnectionStreamsController = NewConnectionStreamsController(sm.Storage)
diff --git a/application_context_test.go b/application_context_test.go
index 2a94cf6..eed0fd6 100644
--- a/application_context_test.go
+++ b/application_context_test.go
@@ -19,9 +19,9 @@ func TestCreateApplicationContext(t *testing.T) {
assert.Nil(t, appContext.RulesManager)
config := Config{
- ServerIP: "10.10.10.10",
- FlagRegex: "FLAG{test}",
- AuthRequired: true,
+ ServerAddress: "10.10.10.10",
+ FlagRegex: "FLAG{test}",
+ AuthRequired: true,
}
accounts := gin.Accounts{
"username": "password",
diff --git a/application_router_test.go b/application_router_test.go
index c8e4474..4225ab9 100644
--- a/application_router_test.go
+++ b/application_router_test.go
@@ -19,7 +19,7 @@ func TestSetupApplication(t *testing.T) {
settings := make(map[string]interface{})
assert.Equal(t, http.StatusServiceUnavailable, toolkit.MakeRequest("GET", "/api/rules", nil).Code)
assert.Equal(t, http.StatusBadRequest, toolkit.MakeRequest("POST", "/setup", settings).Code)
- settings["config"] = Config{ServerIP: "1.2.3.4", FlagRegex: "FLAG{test}", AuthRequired: true}
+ settings["config"] = Config{ServerAddress: "1.2.3.4", FlagRegex: "FLAG{test}", AuthRequired: true}
assert.Equal(t, http.StatusBadRequest, toolkit.MakeRequest("POST", "/setup", settings).Code)
settings["accounts"] = gin.Accounts{"username": "password"}
assert.Equal(t, http.StatusAccepted, toolkit.MakeRequest("POST", "/setup", settings).Code)
@@ -162,7 +162,7 @@ func NewRouterTestToolkit(t *testing.T, withSetup bool) *RouterTestToolkit {
if withSetup {
settings := gin.H{
- "config": Config{ServerIP: "1.2.3.4", FlagRegex: "FLAG{test}", AuthRequired: false},
+ "config": Config{ServerAddress: "1.2.3.4", FlagRegex: "FLAG{test}", AuthRequired: false},
"accounts": gin.Accounts{},
}
toolkit.MakeRequest("POST", "/setup", settings)
diff --git a/connection_handler.go b/connection_handler.go
index ddf5e55..ffe4fac 100644
--- a/connection_handler.go
+++ b/connection_handler.go
@@ -7,6 +7,7 @@ import (
"github.com/google/gopacket/tcpassembly"
log "github.com/sirupsen/logrus"
"hash/fnv"
+ "net"
"sync"
"time"
)
@@ -16,7 +17,7 @@ const initialScannersCapacity = 1024
type BiDirectionalStreamFactory struct {
storage Storage
- serverIP gopacket.Endpoint
+ serverNet net.IPNet
connections map[StreamFlow]ConnectionHandler
mConnections sync.Mutex
rulesManager RulesManager
@@ -46,12 +47,12 @@ type connectionHandlerImpl struct {
otherStream *StreamHandler
}
-func NewBiDirectionalStreamFactory(storage Storage, serverIP gopacket.Endpoint,
+func NewBiDirectionalStreamFactory(storage Storage, serverNet net.IPNet,
rulesManager RulesManager) *BiDirectionalStreamFactory {
factory := &BiDirectionalStreamFactory{
storage: storage,
- serverIP: serverIP,
+ serverNet: serverNet,
connections: make(map[StreamFlow]ConnectionHandler, initialConnectionsCapacity),
mConnections: sync.Mutex{},
rulesManager: rulesManager,
@@ -128,17 +129,18 @@ func (factory *BiDirectionalStreamFactory) releaseScanner(scanner Scanner) {
factory.scanners = append(factory.scanners, scanner)
}
-func (factory *BiDirectionalStreamFactory) New(net, transport gopacket.Flow) tcpassembly.Stream {
- flow := StreamFlow{net.Src(), net.Dst(), transport.Src(), transport.Dst()}
- invertedFlow := StreamFlow{net.Dst(), net.Src(), transport.Dst(), transport.Src()}
+func (factory *BiDirectionalStreamFactory) New(netFlow, transportFlow gopacket.Flow) tcpassembly.Stream {
+ flow := StreamFlow{netFlow.Src(), netFlow.Dst(), transportFlow.Src(), transportFlow.Dst()}
+ invertedFlow := StreamFlow{netFlow.Dst(), netFlow.Src(), transportFlow.Dst(), transportFlow.Src()}
factory.mConnections.Lock()
connection, isPresent := factory.connections[invertedFlow]
+ isServer := factory.serverNet.Contains(netFlow.Src().Raw())
if isPresent {
delete(factory.connections, invertedFlow)
} else {
var connectionFlow StreamFlow
- if net.Src() == factory.serverIP {
+ if isServer {
connectionFlow = invertedFlow
} else {
connectionFlow = flow
@@ -152,7 +154,7 @@ func (factory *BiDirectionalStreamFactory) New(net, transport gopacket.Flow) tcp
}
factory.mConnections.Unlock()
- streamHandler := NewStreamHandler(connection, flow, factory.takeScanner(), net.Src() != factory.serverIP)
+ streamHandler := NewStreamHandler(connection, flow, factory.takeScanner(), !isServer)
return &streamHandler
}
diff --git a/connection_handler_test.go b/connection_handler_test.go
index 7b00f6e..0bee0ac 100644
--- a/connection_handler_test.go
+++ b/connection_handler_test.go
@@ -17,7 +17,7 @@ import (
func TestTakeReleaseScanners(t *testing.T) {
wrapper := NewTestStorageWrapper(t)
- serverIP := layers.NewIPEndpoint(net.ParseIP(testDstIP))
+ serverNet := ParseIPNet(testDstIP)
ruleManager := TestRulesManager{
databaseUpdated: make(chan RulesDatabase),
}
@@ -25,7 +25,7 @@ func TestTakeReleaseScanners(t *testing.T) {
database, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/nope/", 0))
require.NoError(t, err)
- factory := NewBiDirectionalStreamFactory(wrapper.Storage, serverIP, &ruleManager)
+ factory := NewBiDirectionalStreamFactory(wrapper.Storage, *serverNet, &ruleManager)
version := NewRowID()
ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version}
time.Sleep(10 * time.Millisecond)
@@ -88,7 +88,7 @@ func TestConnectionFactory(t *testing.T) {
database, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/nope/", 0))
require.NoError(t, err)
- factory := NewBiDirectionalStreamFactory(wrapper.Storage, serverIP, &ruleManager)
+ factory := NewBiDirectionalStreamFactory(wrapper.Storage, *ParseIPNet(testDstIP), &ruleManager)
version := NewRowID()
ruleManager.DatabaseUpdateChannel() <- RulesDatabase{database, 0, version}
time.Sleep(10 * time.Millisecond)
diff --git a/pcap_importer.go b/pcap_importer.go
index 9d3f5bc..cd6fdfa 100644
--- a/pcap_importer.go
+++ b/pcap_importer.go
@@ -29,7 +29,7 @@ type PcapImporter struct {
sessions map[string]ImportingSession
mAssemblers sync.Mutex
mSessions sync.Mutex
- serverIP gopacket.Endpoint
+ serverNet net.IPNet
}
type ImportingSession struct {
@@ -47,9 +47,8 @@ type ImportingSession struct {
type flowCount [2]int
-func NewPcapImporter(storage Storage, serverIP net.IP, rulesManager RulesManager) *PcapImporter {
- serverEndpoint := layers.NewIPEndpoint(serverIP)
- streamPool := tcpassembly.NewStreamPool(NewBiDirectionalStreamFactory(storage, serverEndpoint, rulesManager))
+func NewPcapImporter(storage Storage, serverNet net.IPNet, rulesManager RulesManager) *PcapImporter {
+ streamPool := tcpassembly.NewStreamPool(NewBiDirectionalStreamFactory(storage, serverNet, rulesManager))
var result []ImportingSession
if err := storage.Find(ImportingSessions).All(&result); err != nil {
@@ -67,7 +66,7 @@ func NewPcapImporter(storage Storage, serverIP net.IP, rulesManager RulesManager
sessions: sessions,
mAssemblers: sync.Mutex{},
mSessions: sync.Mutex{},
- serverIP: serverEndpoint,
+ serverNet: serverNet,
}
}
@@ -198,8 +197,9 @@ func (pi *PcapImporter) parsePcap(session ImportingSession, fileName string, ctx
tcp := packet.TransportLayer().(*layers.TCP)
var servicePort uint16
var index int
- isDstServer := packet.NetworkLayer().NetworkFlow().Dst() == pi.serverIP
- isSrcServer := packet.NetworkLayer().NetworkFlow().Src() == pi.serverIP
+
+ isDstServer := pi.serverNet.Contains(packet.NetworkLayer().NetworkFlow().Dst().Raw())
+ isSrcServer := pi.serverNet.Contains(packet.NetworkLayer().NetworkFlow().Src().Raw())
if isDstServer && !isSrcServer {
servicePort = uint16(tcp.DstPort)
index = 0
@@ -208,7 +208,7 @@ func (pi *PcapImporter) parsePcap(session ImportingSession, fileName string, ctx
index = 1
} else {
session.InvalidPackets++
- // continue // workaround to process packets when services have multiple ips
+ continue
}
fCount, isPresent := session.PacketsPerService[servicePort]
if !isPresent {
diff --git a/pcap_importer_test.go b/pcap_importer_test.go
index 6f9d4a5..74dd2cc 100644
--- a/pcap_importer_test.go
+++ b/pcap_importer_test.go
@@ -4,12 +4,10 @@ import (
"bufio"
"fmt"
"github.com/google/gopacket"
- "github.com/google/gopacket/layers"
"github.com/google/gopacket/tcpassembly"
"github.com/google/gopacket/tcpassembly/tcpreader"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "net"
"os"
"sync"
"testing"
@@ -99,10 +97,9 @@ func TestImportNoTcpPackets(t *testing.T) {
wrapper.Destroy(t)
}
-func newTestPcapImporter(wrapper *TestStorageWrapper, serverIP string) *PcapImporter {
+func newTestPcapImporter(wrapper *TestStorageWrapper, serverAddress string) *PcapImporter {
wrapper.AddCollection(ImportingSessions)
- serverEndpoint := layers.NewIPEndpoint(net.ParseIP(serverIP))
streamPool := tcpassembly.NewStreamPool(&testStreamFactory{})
return &PcapImporter{
@@ -112,7 +109,7 @@ func newTestPcapImporter(wrapper *TestStorageWrapper, serverIP string) *PcapImpo
sessions: make(map[string]ImportingSession),
mAssemblers: sync.Mutex{},
mSessions: sync.Mutex{},
- serverIP: serverEndpoint,
+ serverNet: *ParseIPNet(serverAddress),
}
}
diff --git a/utils.go b/utils.go
index a015b75..a14fdca 100644
--- a/utils.go
+++ b/utils.go
@@ -10,6 +10,7 @@ import (
log "github.com/sirupsen/logrus"
"go.mongodb.org/mongo-driver/bson/primitive"
"io"
+ "net"
"os"
"time"
)
@@ -127,3 +128,26 @@ func CopyFile(dst, src string) error {
}
return out.Close()
}
+
+func ParseIPNet(address string) *net.IPNet {
+ _, network, err := net.ParseCIDR(address)
+ if err != nil {
+ ip := net.ParseIP(address)
+ if ip == nil {
+ return nil
+ }
+
+ size := 0
+ if ip.To4() != nil {
+ size = 32
+ } else {
+ size = 128
+ }
+ network = &net.IPNet{
+ IP: ip,
+ Mask: net.CIDRMask(size, size),
+ }
+ }
+
+ return network
+}