diff options
Diffstat (limited to 'connection_handler.go')
-rw-r--r-- | connection_handler.go | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/connection_handler.go b/connection_handler.go index ddf5e55..ffe4fac 100644 --- a/connection_handler.go +++ b/connection_handler.go @@ -7,6 +7,7 @@ import ( "github.com/google/gopacket/tcpassembly" log "github.com/sirupsen/logrus" "hash/fnv" + "net" "sync" "time" ) @@ -16,7 +17,7 @@ const initialScannersCapacity = 1024 type BiDirectionalStreamFactory struct { storage Storage - serverIP gopacket.Endpoint + serverNet net.IPNet connections map[StreamFlow]ConnectionHandler mConnections sync.Mutex rulesManager RulesManager @@ -46,12 +47,12 @@ type connectionHandlerImpl struct { otherStream *StreamHandler } -func NewBiDirectionalStreamFactory(storage Storage, serverIP gopacket.Endpoint, +func NewBiDirectionalStreamFactory(storage Storage, serverNet net.IPNet, rulesManager RulesManager) *BiDirectionalStreamFactory { factory := &BiDirectionalStreamFactory{ storage: storage, - serverIP: serverIP, + serverNet: serverNet, connections: make(map[StreamFlow]ConnectionHandler, initialConnectionsCapacity), mConnections: sync.Mutex{}, rulesManager: rulesManager, @@ -128,17 +129,18 @@ func (factory *BiDirectionalStreamFactory) releaseScanner(scanner Scanner) { factory.scanners = append(factory.scanners, scanner) } -func (factory *BiDirectionalStreamFactory) New(net, transport gopacket.Flow) tcpassembly.Stream { - flow := StreamFlow{net.Src(), net.Dst(), transport.Src(), transport.Dst()} - invertedFlow := StreamFlow{net.Dst(), net.Src(), transport.Dst(), transport.Src()} +func (factory *BiDirectionalStreamFactory) New(netFlow, transportFlow gopacket.Flow) tcpassembly.Stream { + flow := StreamFlow{netFlow.Src(), netFlow.Dst(), transportFlow.Src(), transportFlow.Dst()} + invertedFlow := StreamFlow{netFlow.Dst(), netFlow.Src(), transportFlow.Dst(), transportFlow.Src()} factory.mConnections.Lock() connection, isPresent := factory.connections[invertedFlow] + isServer := factory.serverNet.Contains(netFlow.Src().Raw()) if isPresent { delete(factory.connections, invertedFlow) } else { var connectionFlow StreamFlow - if net.Src() == factory.serverIP { + if isServer { connectionFlow = invertedFlow } else { connectionFlow = flow @@ -152,7 +154,7 @@ func (factory *BiDirectionalStreamFactory) New(net, transport gopacket.Flow) tcp } factory.mConnections.Unlock() - streamHandler := NewStreamHandler(connection, flow, factory.takeScanner(), net.Src() != factory.serverIP) + streamHandler := NewStreamHandler(connection, flow, factory.takeScanner(), !isServer) return &streamHandler } |