aboutsummaryrefslogtreecommitdiff
path: root/connection_handler.go
diff options
context:
space:
mode:
Diffstat (limited to 'connection_handler.go')
-rw-r--r--connection_handler.go18
1 files changed, 10 insertions, 8 deletions
diff --git a/connection_handler.go b/connection_handler.go
index ddf5e55..ffe4fac 100644
--- a/connection_handler.go
+++ b/connection_handler.go
@@ -7,6 +7,7 @@ import (
"github.com/google/gopacket/tcpassembly"
log "github.com/sirupsen/logrus"
"hash/fnv"
+ "net"
"sync"
"time"
)
@@ -16,7 +17,7 @@ const initialScannersCapacity = 1024
type BiDirectionalStreamFactory struct {
storage Storage
- serverIP gopacket.Endpoint
+ serverNet net.IPNet
connections map[StreamFlow]ConnectionHandler
mConnections sync.Mutex
rulesManager RulesManager
@@ -46,12 +47,12 @@ type connectionHandlerImpl struct {
otherStream *StreamHandler
}
-func NewBiDirectionalStreamFactory(storage Storage, serverIP gopacket.Endpoint,
+func NewBiDirectionalStreamFactory(storage Storage, serverNet net.IPNet,
rulesManager RulesManager) *BiDirectionalStreamFactory {
factory := &BiDirectionalStreamFactory{
storage: storage,
- serverIP: serverIP,
+ serverNet: serverNet,
connections: make(map[StreamFlow]ConnectionHandler, initialConnectionsCapacity),
mConnections: sync.Mutex{},
rulesManager: rulesManager,
@@ -128,17 +129,18 @@ func (factory *BiDirectionalStreamFactory) releaseScanner(scanner Scanner) {
factory.scanners = append(factory.scanners, scanner)
}
-func (factory *BiDirectionalStreamFactory) New(net, transport gopacket.Flow) tcpassembly.Stream {
- flow := StreamFlow{net.Src(), net.Dst(), transport.Src(), transport.Dst()}
- invertedFlow := StreamFlow{net.Dst(), net.Src(), transport.Dst(), transport.Src()}
+func (factory *BiDirectionalStreamFactory) New(netFlow, transportFlow gopacket.Flow) tcpassembly.Stream {
+ flow := StreamFlow{netFlow.Src(), netFlow.Dst(), transportFlow.Src(), transportFlow.Dst()}
+ invertedFlow := StreamFlow{netFlow.Dst(), netFlow.Src(), transportFlow.Dst(), transportFlow.Src()}
factory.mConnections.Lock()
connection, isPresent := factory.connections[invertedFlow]
+ isServer := factory.serverNet.Contains(netFlow.Src().Raw())
if isPresent {
delete(factory.connections, invertedFlow)
} else {
var connectionFlow StreamFlow
- if net.Src() == factory.serverIP {
+ if isServer {
connectionFlow = invertedFlow
} else {
connectionFlow = flow
@@ -152,7 +154,7 @@ func (factory *BiDirectionalStreamFactory) New(net, transport gopacket.Flow) tcp
}
factory.mConnections.Unlock()
- streamHandler := NewStreamHandler(connection, flow, factory.takeScanner(), net.Src() != factory.serverIP)
+ streamHandler := NewStreamHandler(connection, flow, factory.takeScanner(), !isServer)
return &streamHandler
}