Diffstat:
-rw-r--r--  README.md                 2
-rw-r--r--  caronte.go                8
-rw-r--r--  caronte_test.go          50
-rw-r--r--  connection_handler.go   200
-rw-r--r--  go.mod                    1
-rw-r--r--  pcap_importer.go          9
-rw-r--r--  storage.go               57
-rw-r--r--  storage_test.go          44
-rw-r--r--  stream_factory.go        29
-rw-r--r--  stream_handler.go       310
-rw-r--r--  stream_handler_test.go  193
11 files changed, 631 insertions, 272 deletions
diff --git a/README.md b/README.md
index 888b88d..9fc1fbb 100644
--- a/README.md
+++ b/README.md
@@ -2,6 +2,7 @@
[![Build Status](https://travis-ci.com/eciavatta/caronte.svg?branch=develop)](https://travis-ci.com/eciavatta/caronte)
[![codecov](https://codecov.io/gh/eciavatta/caronte/branch/develop/graph/badge.svg)](https://codecov.io/gh/eciavatta/caronte)
+[![GPL License][license-shield]][license-url]
<img align="left" src="https://divinacommedia.weebly.com/uploads/5/5/2/3/5523249/1299707879.jpg">
Caronte is a tool to analyze the network flow during capture the flag events of type attack/defence.
@@ -9,4 +10,3 @@ It reassembles TCP packets captured in pcap files to rebuild TCP connections, an
The patterns can be defined as regex or using protocol specific rules.
The connection flows are saved into a database and can be visualized with the web application. REST API are also provided.
-Packets can be captured locally on the same machine or can be imported remotely. The streams of bytes extracted from the TCP payload of packets are processed by [Hyperscan](https://github.com/intel/hyperscan), an high-performance regular expression matching library. // TODO
\ No newline at end of file
diff --git a/caronte.go b/caronte.go
index 96bb281..6b33318 100644
--- a/caronte.go
+++ b/caronte.go
@@ -2,17 +2,19 @@ package main
import (
"fmt"
+ "net"
)
func main() {
- // testStorage()
- storage := NewStorage("localhost", 27017, "testing")
+ // pattern.Flags |= hyperscan.SomLeftMost
+
+ storage := NewMongoStorage("localhost", 27017, "testing")
err := storage.Connect(nil)
if err != nil {
panic(err)
}
- importer := NewPcapImporter(&storage, "10.10.10.10")
+ importer := NewPcapImporter(storage, net.ParseIP("10.10.10.10"))
sessionId, err := importer.ImportPcap("capture_00459_20190627165500.pcap")
if err != nil {
diff --git a/caronte_test.go b/caronte_test.go
new file mode 100644
index 0000000..cbf867d
--- /dev/null
+++ b/caronte_test.go
@@ -0,0 +1,50 @@
+package main
+
+import (
+ "context"
+ "crypto/sha256"
+ "fmt"
+ "go.mongodb.org/mongo-driver/mongo"
+ "go.mongodb.org/mongo-driver/mongo/options"
+ "os"
+ "testing"
+ "time"
+)
+
+var storage Storage
+var testContext context.Context
+
+func TestMain(m *testing.M) {
+ mongoHost, ok := os.LookupEnv("MONGO_HOST")
+ if !ok {
+ mongoHost = "localhost"
+ }
+ mongoPort, ok := os.LookupEnv("MONGO_PORT")
+ if !ok {
+ mongoPort = "27017"
+ }
+
+ uniqueDatabaseName := sha256.Sum256([]byte(time.Now().String()))
+
+ client, err := mongo.NewClient(options.Client().ApplyURI(fmt.Sprintf("mongodb://%s:%v", mongoHost, mongoPort)))
+ if err != nil {
+ panic("failed to create mongo client")
+ }
+
+ db := client.Database(fmt.Sprintf("%x", uniqueDatabaseName[:31]))
+ mongoStorage := MongoStorage{
+ client: client,
+ collections: map[string]*mongo.Collection{testCollection: db.Collection(testCollection)},
+ }
+
+ testContext, _ = context.WithTimeout(context.Background(), 10 * time.Second)
+
+ err = mongoStorage.Connect(nil)
+ if err != nil {
+ panic(err)
+ }
+ storage = &mongoStorage
+
+ exitCode := m.Run()
+ os.Exit(exitCode)
+}
diff --git a/connection_handler.go b/connection_handler.go
new file mode 100644
index 0000000..36dca6d
--- /dev/null
+++ b/connection_handler.go
@@ -0,0 +1,200 @@
+package main
+
+import (
+ "context"
+ "encoding/binary"
+ "fmt"
+ "github.com/flier/gohs/hyperscan"
+ "github.com/google/gopacket"
+ "github.com/google/gopacket/tcpassembly"
+ "log"
+ "sync"
+ "time"
+)
+
+type BiDirectionalStreamFactory struct {
+ storage Storage
+ serverIp gopacket.Endpoint
+ connections map[StreamKey]ConnectionHandler
+ mConnections sync.Mutex
+ patterns hyperscan.StreamDatabase
+ mPatterns sync.Mutex
+ scratches []*hyperscan.Scratch
+}
+
+type StreamKey [4]gopacket.Endpoint
+
+type ConnectionHandler interface {
+ Complete(handler *StreamHandler)
+ Storage() Storage
+ Context() context.Context
+ Patterns() hyperscan.StreamDatabase
+}
+
+type connectionHandlerImpl struct {
+ storage Storage
+ net, transport gopacket.Flow
+ initiator StreamKey
+ connectionKey string
+ mComplete sync.Mutex
+ otherStream *StreamHandler
+ context context.Context
+ patterns hyperscan.StreamDatabase
+}
+
+func (factory *BiDirectionalStreamFactory) New(net, transport gopacket.Flow) tcpassembly.Stream {
+ key := StreamKey{net.Src(), net.Dst(), transport.Src(), transport.Dst()}
+ invertedKey := StreamKey{net.Dst(), net.Src(), transport.Dst(), transport.Src()}
+
+ factory.mConnections.Lock()
+ connection, isPresent := factory.connections[invertedKey]
+ if isPresent {
+ delete(factory.connections, invertedKey)
+ } else {
+ var initiator StreamKey
+ if net.Src() == factory.serverIp {
+ initiator = invertedKey
+ } else {
+ initiator = key
+ }
+ connection = &connectionHandlerImpl{
+ storage: factory.storage,
+ net: net,
+ transport: transport,
+ initiator: initiator,
+ mComplete: sync.Mutex{},
+ context: context.Background(),
+ patterns: factory.patterns,
+ }
+ factory.connections[key] = connection
+ }
+ factory.mConnections.Unlock()
+
+ streamHandler := NewStreamHandler(connection, key, factory.takeScratch())
+
+ return &streamHandler
+}
+
+func (factory *BiDirectionalStreamFactory) UpdatePatternsDatabase(database hyperscan.StreamDatabase) {
+ factory.mPatterns.Lock()
+ factory.patterns = database
+
+ for _, s := range factory.scratches {
+ err := s.Realloc(database)
+ if err != nil {
+ fmt.Println("failed to realloc an existing scratch")
+ }
+ }
+
+ factory.mPatterns.Unlock()
+}
+
+func (ch *connectionHandlerImpl) Complete(handler *StreamHandler) {
+ ch.mComplete.Lock()
+ if ch.otherStream == nil {
+ ch.otherStream = handler
+ ch.mComplete.Unlock()
+ return
+ }
+ ch.mComplete.Unlock()
+
+ var startedAt, closedAt time.Time
+ if handler.firstPacketSeen.Before(ch.otherStream.firstPacketSeen) {
+ startedAt = handler.firstPacketSeen
+ } else {
+ startedAt = ch.otherStream.firstPacketSeen
+ }
+
+ if handler.lastPacketSeen.After(ch.otherStream.lastPacketSeen) {
+ closedAt = handler.lastPacketSeen
+ } else {
+ closedAt = ch.otherStream.lastPacketSeen
+ }
+
+ var client, server *StreamHandler
+ if handler.streamKey == ch.initiator {
+ client = handler
+ server = ch.otherStream
+ } else {
+ client = ch.otherStream
+ server = handler
+ }
+
+ ch.generateConnectionKey(startedAt)
+
+ _, err := ch.storage.InsertOne(ch.context, "connections", OrderedDocument{
+ {"_id", ch.connectionKey},
+ {"ip_src", ch.initiator[0].String()},
+ {"ip_dst", ch.initiator[1].String()},
+ {"port_src", ch.initiator[2].String()},
+ {"port_dst", ch.initiator[3].String()},
+ {"started_at", startedAt},
+ {"closed_at", closedAt},
+ {"client_bytes", client.streamLength},
+ {"server_bytes", server.streamLength},
+ {"client_documents", len(client.documentsKeys)},
+ {"server_documents", len(server.documentsKeys)},
+ {"processed_at", time.Now()},
+ })
+ if err != nil {
+ log.Println("error inserting document on collection connections with _id = ", ch.connectionKey)
+ }
+
+ streamsIds := append(client.documentsKeys, server.documentsKeys...)
+ n, err := ch.storage.UpdateMany(ch.context, "connection_streams",
+ UnorderedDocument{"_id": UnorderedDocument{"$in": streamsIds}},
+ UnorderedDocument{"connection_id": ch.connectionKey},
+ false)
+ if err != nil {
+ log.Println("failed to update connection streams", err)
+ }
+ if n != int64(len(streamsIds)) {
+ log.Println("failed to update all connection streams")
+ }
+}
+
+func (ch *connectionHandlerImpl) Storage() Storage {
+ return ch.storage
+}
+
+func (ch *connectionHandlerImpl) Context() context.Context {
+ return ch.context
+}
+
+func (ch *connectionHandlerImpl) Patterns() hyperscan.StreamDatabase {
+ return ch.patterns
+}
+
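+// generateConnectionKey builds a 16-byte id: bytes 0-7 hold the first-packet timestamp, bytes 8-15 the xor of the network and transport flow hashes.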
+func (ch *connectionHandlerImpl) generateConnectionKey(firstPacketSeen time.Time) {
+ hash := make([]byte, 16)
+ binary.BigEndian.PutUint64(hash, uint64(firstPacketSeen.UnixNano()))
+ binary.BigEndian.PutUint64(hash[8:], ch.net.FastHash()^ch.transport.FastHash())
+
+ ch.connectionKey = fmt.Sprintf("%x", hash)
+}
+
+func (factory *BiDirectionalStreamFactory) takeScratch() *hyperscan.Scratch {
+ factory.mPatterns.Lock()
+ defer factory.mPatterns.Unlock()
+
+ if len(factory.scratches) == 0 {
+ scratch, err := hyperscan.NewScratch(factory.patterns)
+ if err != nil {
+ fmt.Println("failed to alloc a new scratch")
+ }
+
+ return scratch
+ }
+
+ index := len(factory.scratches) - 1
+ scratch := factory.scratches[index]
+ factory.scratches = factory.scratches[:index]
+
+ return scratch
+}
+
+func (factory *BiDirectionalStreamFactory) releaseScratch(scratch *hyperscan.Scratch) {
+ factory.mPatterns.Lock()
+ factory.scratches = append(factory.scratches, scratch)
+ factory.mPatterns.Unlock()
+}
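The factory keeps a pool of reusable Hyperscan scratch spaces and lets callers swap the pattern database at runtime via UpdatePatternsDatabase. A hypothetical wiring sketch built on the types added here (newTestFactory, the server IP, and the pattern are illustrative, not part of the commit):

```go
package main

import (
	"net"

	"github.com/flier/gohs/hyperscan"
	"github.com/google/gopacket/layers"
)

// newTestFactory shows how a BiDirectionalStreamFactory is assembled:
// the connections map must be initialized before the first call to New.
func newTestFactory(storage Storage) (*BiDirectionalStreamFactory, error) {
	factory := &BiDirectionalStreamFactory{
		storage:     storage,
		serverIp:    layers.NewIPEndpoint(net.ParseIP("10.10.10.10")),
		connections: make(map[StreamKey]ConnectionHandler),
	}

	// Compile an initial database; UpdatePatternsDatabase can later swap in
	// a recompiled one, reallocating every pooled scratch to match it.
	db, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern(`flag\{\w+\}`, hyperscan.SomLeftMost))
	if err != nil {
		return nil, err
	}
	factory.UpdatePatternsDatabase(db)
	return factory, nil
}
```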
diff --git a/go.mod b/go.mod
index 5ea26e3..66af307 100644
--- a/go.mod
+++ b/go.mod
@@ -5,6 +5,7 @@ go 1.14
require (
github.com/flier/gohs v1.0.0
github.com/google/gopacket v1.1.17
+ github.com/stretchr/testify v1.3.0
go.mongodb.org/mongo-driver v1.3.1
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3
)
diff --git a/pcap_importer.go b/pcap_importer.go
index 9428b29..c6260de 100644
--- a/pcap_importer.go
+++ b/pcap_importer.go
@@ -24,7 +24,7 @@ const importedPcapsCollectionName = "imported_pcaps"
type PcapImporter struct {
- storage *Storage
+ storage Storage
streamPool *tcpassembly.StreamPool
assemblers []*tcpassembly.Assembler
sessions map[string]context.CancelFunc
@@ -36,10 +36,11 @@ type PcapImporter struct {
type flowCount [2]int
-func NewPcapImporter(storage *Storage, serverIp string) *PcapImporter {
+func NewPcapImporter(storage Storage, serverIp net.IP) *PcapImporter {
+ serverEndpoint := layers.NewIPEndpoint(serverIp)
streamFactory := &BiDirectionalStreamFactory{
storage: storage,
- serverIp: serverIp,
+ serverIp: serverEndpoint,
+ connections: make(map[StreamKey]ConnectionHandler),
}
streamPool := tcpassembly.NewStreamPool(streamFactory)
@@ -50,7 +51,7 @@ func NewPcapImporter(storage *Storage, serverIp string) *PcapImporter {
sessions: make(map[string]context.CancelFunc),
mAssemblers: sync.Mutex{},
mSessions: sync.Mutex{},
- serverIp: layers.NewIPEndpoint(net.ParseIP(serverIp)),
+ serverIp: serverEndpoint,
}
}
diff --git a/storage.go b/storage.go
index e8f6645..ea24780 100644
--- a/storage.go
+++ b/storage.go
@@ -14,7 +14,13 @@ import (
const defaultConnectionTimeout = 10*time.Second
const defaultOperationTimeout = 3*time.Second
-type Storage struct {
+type Storage interface {
+ InsertOne(ctx context.Context, collectionName string, document interface{}) (interface{}, error)
+ UpdateOne(ctx context.Context, collectionName string, filter interface{}, update interface{}, upsert bool) (interface{}, error)
+ UpdateMany(ctx context.Context, collectionName string, filter interface{}, update interface{}, upsert bool) (interface{}, error)
+ FindOne(ctx context.Context, collectionName string, filter interface{}) (UnorderedDocument, error)
+}
+
+type MongoStorage struct {
client *mongo.Client
collections map[string]*mongo.Collection
}
@@ -22,7 +28,7 @@ type Storage struct {
type OrderedDocument = bson.D
type UnorderedDocument = bson.M
-func NewStorage(uri string, port int, database string) Storage {
+func NewMongoStorage(uri string, port int, database string) *MongoStorage {
opt := options.Client()
opt.ApplyURI(fmt.Sprintf("mongodb://%s:%v", uri, port))
client, err := mongo.NewClient(opt)
@@ -36,13 +42,13 @@ func NewStorage(uri string, port int, database string) Storage {
"connections": db.Collection("connections"),
}
- return Storage{
+ return &MongoStorage{
client: client,
collections: colls,
}
}
-func (storage *Storage) Connect(ctx context.Context) error {
+func (storage *MongoStorage) Connect(ctx context.Context) error {
if ctx == nil {
ctx, _ = context.WithTimeout(context.Background(), defaultConnectionTimeout)
}
@@ -50,7 +56,7 @@ func (storage *Storage) Connect(ctx context.Context) error {
return storage.client.Connect(ctx)
}
-func (storage *Storage) InsertOne(ctx context.Context, collectionName string,
+func (storage *MongoStorage) InsertOne(ctx context.Context, collectionName string,
document interface{}) (interface{}, error) {
collection, ok := storage.collections[collectionName]
@@ -70,7 +76,7 @@ func (storage *Storage) InsertOne(ctx context.Context, collectionName string,
return result.InsertedID, nil
}
-func (storage *Storage) UpdateOne(ctx context.Context, collectionName string,
+func (storage *MongoStorage) UpdateOne(ctx context.Context, collectionName string,
filter interface{}, update interface {}, upsert bool) (interface{}, error) {
collection, ok := storage.collections[collectionName]
@@ -97,7 +103,30 @@ func (storage *Storage) UpdateOne(ctx context.Context, collectionName string,
return result.ModifiedCount == 1, nil
}
-func (storage *Storage) FindOne(ctx context.Context, collectionName string,
+func (storage *MongoStorage) UpdateMany(ctx context.Context, collectionName string,
+ filter interface{}, update interface{}, upsert bool) (interface{}, error) {
+
+ collection, ok := storage.collections[collectionName]
+ if !ok {
+ return nil, errors.New("invalid collection: " + collectionName)
+ }
+
+ if ctx == nil {
+ ctx, _ = context.WithTimeout(context.Background(), defaultOperationTimeout)
+ }
+
+ opts := options.Update().SetUpsert(upsert)
+ update = bson.D{{"$set", update}}
+
+ result, err := collection.UpdateMany(ctx, filter, update, opts)
+ if err != nil {
+ return nil, err
+ }
+
+ return result.ModifiedCount, nil
+}
+
+func (storage *MongoStorage) FindOne(ctx context.Context, collectionName string,
filter interface{}) (UnorderedDocument, error) {
collection, ok := storage.collections[collectionName]
@@ -121,17 +150,3 @@ func (storage *Storage) FindOne(ctx context.Context, collectionName string,
return result, nil
}
-
-
-func testStorage() {
- storage := NewStorage("localhost", 27017, "testing")
- _ = storage.Connect(nil)
-
- id, err := storage.InsertOne(nil, "connections", bson.M{"_id": "provaaa"})
- if err != nil {
- panic(err)
- } else {
- fmt.Println(id)
- }
-
-}
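With Storage defined as an interface, call sites can run against MongoDB in production and a mock in tests. A minimal usage sketch against MongoStorage, assuming a mongod reachable on localhost (database name and document are illustrative; a nil context falls back to the default timeouts defined above):

```go
package main

import "fmt"

func exampleStorageUse() {
	storage := NewMongoStorage("localhost", 27017, "testing")
	if err := storage.Connect(nil); err != nil {
		panic(err)
	}

	// insert a document, then read it back by its id
	id, err := storage.InsertOne(nil, "connections", OrderedDocument{{"_id", "example"}})
	if err != nil {
		panic(err)
	}
	doc, err := storage.FindOne(nil, "connections", UnorderedDocument{"_id": id})
	if err != nil {
		panic(err)
	}
	fmt.Println(doc)
}
```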
diff --git a/storage_test.go b/storage_test.go
index 6b36833..b46b60a 100644
--- a/storage_test.go
+++ b/storage_test.go
@@ -1,21 +1,11 @@
package main
import (
- "crypto/sha256"
- "fmt"
- "go.mongodb.org/mongo-driver/mongo"
- "go.mongodb.org/mongo-driver/mongo/options"
- "golang.org/x/net/context"
- "os"
"testing"
- "time"
)
const testCollection = "characters"
-var storage Storage
-var testContext context.Context
-
func testInsert(t *testing.T) {
// insert a document in an invalid connection
insertedId, err := storage.InsertOne(testContext, "invalid_collection",
@@ -98,37 +88,3 @@ func TestBasicOperations(t *testing.T) {
t.Run("testInsert", testInsert)
t.Run("testFindOne", testFindOne)
}
-
-func TestMain(m *testing.M) {
- mongoHost, ok := os.LookupEnv("MONGO_HOST")
- if !ok {
- mongoHost = "localhost"
- }
- mongoPort, ok := os.LookupEnv("MONGO_PORT")
- if !ok {
- mongoPort = "27017"
- }
-
- uniqueDatabaseName := sha256.Sum256([]byte(time.Now().String()))
-
- client, err := mongo.NewClient(options.Client().ApplyURI(fmt.Sprintf("mongodb://%s:%v", mongoHost, mongoPort)))
- if err != nil {
- panic("failed to create mongo client")
- }
-
- db := client.Database(fmt.Sprintf("%x", uniqueDatabaseName[:31]))
- storage = Storage{
- client: client,
- collections: map[string]*mongo.Collection{testCollection: db.Collection(testCollection)},
- }
-
- testContext, _ = context.WithTimeout(context.Background(), 10 * time.Second)
-
- err = storage.Connect(nil)
- if err != nil {
- panic(err)
- }
-
- exitCode := m.Run()
- os.Exit(exitCode)
-}
diff --git a/stream_factory.go b/stream_factory.go
deleted file mode 100644
index e1d76a4..0000000
--- a/stream_factory.go
+++ /dev/null
@@ -1,29 +0,0 @@
-package main
-
-import (
- "github.com/google/gopacket"
- "github.com/google/gopacket/tcpassembly"
-)
-
-type BiDirectionalStreamFactory struct {
- storage *Storage
- serverIp string
-}
-
-// httpStream will handle the actual decoding of http requests.
-type uniDirectionalStream struct {
- net, transport gopacket.Flow
- r StreamHandler
-}
-
-func (h *BiDirectionalStreamFactory) New(net, transport gopacket.Flow) tcpassembly.Stream {
- hstream := &uniDirectionalStream{
- net: net,
- transport: transport,
- r: NewStreamHandler(),
- }
- // go hstream.run() // Important... we must guarantee that data from the tcpreader stream is read.
-
- // StreamHandler implements tcpassembly.Stream, so we can return a pointer to it.
- return &hstream.r
-}
diff --git a/stream_handler.go b/stream_handler.go
index ad59856..80d91d6 100644
--- a/stream_handler.go
+++ b/stream_handler.go
@@ -1,206 +1,176 @@
-// Package tcpreader provides an implementation for tcpassembly.Stream which presents
-// the caller with an io.Reader for easy processing.
-//
-// The assembly package handles packet data reordering, but its output is
-// library-specific, thus not usable by the majority of external Go libraries.
-// The io.Reader interface, on the other hand, is used throughout much of Go
-// code as an easy mechanism for reading in data streams and decoding them. For
-// example, the net/http package provides the ReadRequest function, which can
-// parse an HTTP request from a live data stream, just what we'd want when
-// sniffing HTTP traffic. Using StreamHandler, this is relatively easy to set
-// up:
-//
-// // Create our StreamFactory
-// type httpStreamFactory struct {}
-// func (f *httpStreamFactory) New(a, b gopacket.Flow) {
-// r := tcpreader.NewReaderStream(false)
-// go printRequests(r)
-// return &r
-// }
-// func printRequests(r io.Reader) {
-// // Convert to bufio, since that's what ReadRequest wants.
-// buf := bufio.NewReader(r)
-// for {
-// if req, err := http.ReadRequest(buf); err == io.EOF {
-// return
-// } else if err != nil {
-// log.Println("Error parsing HTTP requests:", err)
-// } else {
-// fmt.Println("HTTP REQUEST:", req)
-// fmt.Println("Body contains", tcpreader.DiscardBytesToEOF(req.Body), "bytes")
-// }
-// }
-// }
-//
-// Using just this code, we're able to reference a powerful, built-in library
-// for HTTP request parsing to do all the dirty-work of parsing requests from
-// the wire in real-time. Pass this stream factory to an tcpassembly.StreamPool,
-// start up an tcpassembly.Assembler, and you're good to go!
package main
import (
-"errors"
-"github.com/google/gopacket/tcpassembly"
-"io"
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "github.com/flier/gohs/hyperscan"
+ "github.com/google/gopacket/tcpassembly"
+ "log"
+ "time"
)
-var discardBuffer = make([]byte, 4096)
-
-// DiscardBytesToFirstError will read in all bytes up to the first error
-// reported by the given reader, then return the number of bytes discarded
-// and the error encountered.
-func DiscardBytesToFirstError(r io.Reader) (discarded int, err error) {
- for {
- n, e := r.Read(discardBuffer)
- discarded += n
- if e != nil {
- return discarded, e
- }
- }
-}
+const MaxDocumentSize = 1024 * 1024
+const InitialBlockCount = 1024
+const InitialPatternSliceSize = 8
-// DiscardBytesToEOF will read in all bytes from a Reader until it
-// encounters an io.EOF, then return the number of bytes. Be careful
-// of this... if used on a Reader that returns a non-io.EOF error
-// consistently, this will loop forever discarding that error while
-// it waits for an EOF.
-func DiscardBytesToEOF(r io.Reader) (discarded int) {
- for {
- n, e := DiscardBytesToFirstError(r)
- discarded += n
- if e == io.EOF {
- return
- }
- }
-}
-
-// StreamHandler implements both tcpassembly.Stream and io.Reader. You can use it
-// as a building block to make simple, easy stream handlers.
-//
// IMPORTANT: If you use a StreamHandler, you MUST read ALL BYTES from it,
// quickly. Not reading available bytes will block TCP stream reassembly. It's
// a common pattern to do this by starting a goroutine in the factory's New
// method:
-//
-// type myStreamHandler struct {
-// r StreamHandler
-// }
-// func (m *myStreamHandler) run() {
-// // Do something here that reads all of the StreamHandler, or your assembly
-// // will block.
-// fmt.Println(tcpreader.DiscardBytesToEOF(&m.r))
-// }
-// func (f *myStreamFactory) New(a, b gopacket.Flow) tcpassembly.Stream {
-// s := &myStreamHandler{}
-// go s.run()
-// // Return the StreamHandler as the stream that assembly should populate.
-// return &s.r
-// }
type StreamHandler struct {
- ReaderStreamOptions
- reassembled chan []tcpassembly.Reassembly
- done chan bool
- current []tcpassembly.Reassembly
- closed bool
- lossReported bool
- first bool
- initiated bool
+ connection ConnectionHandler
+ streamKey StreamKey
+ buffer *bytes.Buffer
+ indexes []int
+ timestamps []time.Time
+ lossBlocks []bool
+ currentIndex int
+ firstPacketSeen time.Time
+ lastPacketSeen time.Time
+ documentsKeys []string
+ streamLength int
+ patternStream hyperscan.Stream
+ patternMatches map[uint][]PatternSlice
}
-// ReaderStreamOptions provides user-resettable options for a StreamHandler.
-type ReaderStreamOptions struct {
- // LossErrors determines whether this stream will return
- // ReaderStreamDataLoss errors from its Read function whenever it
- // determines data has been lost.
- LossErrors bool
-}
+type PatternSlice [2]uint64
// NewReaderStream returns a new StreamHandler object.
-func NewStreamHandler() StreamHandler {
- r := StreamHandler{
- reassembled: make(chan []tcpassembly.Reassembly),
- done: make(chan bool),
- first: true,
- initiated: true,
+func NewStreamHandler(connection ConnectionHandler, key StreamKey, scratch *hyperscan.Scratch) StreamHandler {
+ handler := StreamHandler{
+ connection: connection,
+ streamKey: key,
+ buffer: new(bytes.Buffer),
+ indexes: make([]int, 0, InitialBlockCount),
+ timestamps: make([]time.Time, 0, InitialBlockCount),
+ lossBlocks: make([]bool, 0, InitialBlockCount),
+ documentsKeys: make([]string, 0, 1), // most of the time the stream fits in one document
+ patternMatches: make(map[uint][]PatternSlice, 10), // TODO: change with exact value
}
- return r
+
+ stream, err := connection.Patterns().Open(0, scratch, handler.onMatch, nil)
+ if err != nil {
+ log.Println("failed to create a stream: ", err)
+ }
+ handler.patternStream = stream
+
+ return handler
}
// Reassembled implements tcpassembly.Stream's Reassembled function.
-func (r *StreamHandler) Reassembled(reassembly []tcpassembly.Reassembly) {
- if !r.initiated {
- panic("StreamHandler not created via NewReaderStream")
+func (sh *StreamHandler) Reassembled(reassembly []tcpassembly.Reassembly) {
+ for _, r := range reassembly {
+ skip := r.Skip
+ if r.Start {
+ skip = 0
+ sh.firstPacketSeen = r.Seen
+ }
+ if r.End {
+ sh.lastPacketSeen = r.Seen
+ }
+
+ reassemblyLen := len(r.Bytes)
+ if reassemblyLen == 0 {
+ continue
+ }
+
+ if sh.buffer.Len()+len(r.Bytes)-skip > MaxDocumentSize {
+ sh.storageCurrentDocument()
+ sh.resetCurrentDocument()
+ }
+ n, err := sh.buffer.Write(r.Bytes[skip:])
+ if err != nil {
+ log.Println("error while copying bytes from Reassemble in stream_handler")
+ return
+ }
+ sh.indexes = append(sh.indexes, sh.currentIndex)
+ sh.timestamps = append(sh.timestamps, r.Seen)
+ sh.lossBlocks = append(sh.lossBlocks, skip != 0)
+ sh.currentIndex += n
+ sh.streamLength += n
+
+ err = sh.patternStream.Scan(r.Bytes)
+ if err != nil {
+ log.Println("failed to scan packet buffer: ", err)
+ }
}
- r.reassembled <- reassembly
- <-r.done
}
// ReassemblyComplete implements tcpassembly.Stream's ReassemblyComplete function.
-func (r *StreamHandler) ReassemblyComplete() {
- close(r.reassembled)
- close(r.done)
-}
+func (sh *StreamHandler) ReassemblyComplete() {
+ err := sh.patternStream.Close()
+ if err != nil {
+ log.Println("failed to close pattern stream: ", err)
+ }
-// stripEmpty strips empty reassembly slices off the front of its current set of
-// slices.
-func (r *StreamHandler) stripEmpty() {
- for len(r.current) > 0 && len(r.current[0].Bytes) == 0 {
- r.current = r.current[1:]
- r.lossReported = false
+ if sh.currentIndex > 0 {
+ sh.storageCurrentDocument()
}
+ sh.connection.Complete(sh)
}
-// DataLost is returned by the StreamHandler's Read function when it encounters
-// a Reassembly with Skip != 0.
-var DataLost = errors.New("lost data")
-
-// Read implements io.Reader's Read function.
-// Given a byte slice, it will either copy a non-zero number of bytes into
-// that slice and return the number of bytes and a nil error, or it will
-// leave slice p as is and return 0, io.EOF.
-func (r *StreamHandler) Read(p []byte) (int, error) {
- if !r.initiated {
- panic("StreamHandler not created via NewReaderStream")
- }
- var ok bool
- r.stripEmpty()
- for !r.closed && len(r.current) == 0 {
- if r.first {
- r.first = false
- } else {
- r.done <- true
- }
- if r.current, ok = <-r.reassembled; ok {
- r.stripEmpty()
- } else {
- r.closed = true
- }
+func (sh *StreamHandler) resetCurrentDocument() {
+ sh.buffer.Reset()
+ sh.indexes = sh.indexes[:0]
+ sh.timestamps = sh.timestamps[:0]
+ sh.lossBlocks = sh.lossBlocks[:0]
+ sh.currentIndex = 0
+
+ for key, val := range sh.patternMatches {
+ sh.patternMatches[key] = val[:0]
+ }
- if len(r.current) > 0 {
- current := &r.current[0]
- if r.LossErrors && !r.lossReported && current.Skip != 0 {
- r.lossReported = true
- return 0, DataLost
+}
+
+func (sh *StreamHandler) onMatch(id uint, from uint64, to uint64, flags uint, context interface{}) error {
+ patternSlices, isPresent := sh.patternMatches[id]
+ if isPresent {
+ if len(patternSlices) > 0 {
+ lastElement := &patternSlices[len(patternSlices)-1]
+ if lastElement[0] == from { // make the regex greedy to match the maximum number of chars
+ lastElement[1] = to
+ return nil
+ }
}
- length := copy(p, current.Bytes)
- current.Bytes = current.Bytes[length:]
- return length, nil
+ // new from == new match
+ sh.patternMatches[id] = append(patternSlices, PatternSlice{from, to})
+ } else {
+ patternSlices = make([]PatternSlice, 1, InitialPatternSliceSize)
+ patternSlices[0] = PatternSlice{from, to}
+ sh.patternMatches[id] = patternSlices
}
- return 0, io.EOF
+
+ return nil
}
-// Close implements io.Closer's Close function, making StreamHandler a
-// io.ReadCloser. It discards all remaining bytes in the reassembly in a
-// manner that's safe for the assembler (IE: it doesn't block).
-func (r *StreamHandler) Close() error {
- r.current = nil
- r.closed = true
- for {
- if _, ok := <-r.reassembled; !ok {
- return nil
- }
- r.done <- true
+func (sh *StreamHandler) storageCurrentDocument() {
+ streamKey := sh.generateDocumentKey()
+
+ _, err := sh.connection.Storage().InsertOne(sh.connection.Context(), "connection_streams", OrderedDocument{
+ {"_id", streamKey},
+ {"connection_id", nil},
+ {"document_index", len(sh.documentsKeys)},
+ {"payload", sh.buffer.Bytes()},
+ {"blocks_indexes", sh.indexes},
+ {"blocks_timestamps", sh.timestamps},
+ {"blocks_loss", sh.lossBlocks},
+ {"pattern_matches", sh.patternMatches},
+ })
+
+ if err != nil {
+ log.Println("failed to insert connection stream: ", err)
}
+
+ sh.documentsKeys = append(sh.documentsKeys, streamKey)
}
+func (sh *StreamHandler) generateDocumentKey() string {
+ hash := make([]byte, 16)
+ endpointsHash := sh.streamKey[0].FastHash() ^ sh.streamKey[1].FastHash() ^
+ sh.streamKey[2].FastHash() ^ sh.streamKey[3].FastHash()
+ binary.BigEndian.PutUint64(hash, endpointsHash)
+ binary.BigEndian.PutUint64(hash[8:], uint64(sh.timestamps[0].UnixNano()))
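+ // the document counter below overwrites bytes 8-9 of the timestamp, so successive documents of one stream get distinct keys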
+ binary.BigEndian.PutUint16(hash[8:], uint16(len(sh.documentsKeys)))
+ return fmt.Sprintf("%x", hash)
+}
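Hyperscan reports a separate match event for every end offset of a pattern; the onMatch callback above collapses events that share a start offset into a single greedy PatternSlice. A self-contained sketch of the same merging technique, assuming gohs v1.0.0 as pinned in go.mod (the pattern and payload are illustrative):

```go
package main

import (
	"fmt"

	"github.com/flier/gohs/hyperscan"
)

func main() {
	// "ab+" with SomLeftMost makes Hyperscan report ab, abb, abbb for "xxabbb",
	// all starting at offset 2.
	db, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern(`ab+`, hyperscan.SomLeftMost))
	if err != nil {
		panic(err)
	}
	scratch, err := hyperscan.NewScratch(db)
	if err != nil {
		panic(err)
	}

	matches := make(map[uint][][2]uint64)
	onMatch := func(id uint, from, to uint64, flags uint, ctx interface{}) error {
		slices := matches[id]
		// merge reports sharing a start offset, keeping only the longest match
		if n := len(slices); n > 0 && slices[n-1][0] == from {
			slices[n-1][1] = to
		} else {
			slices = append(slices, [2]uint64{from, to})
		}
		matches[id] = slices
		return nil
	}

	stream, err := db.Open(0, scratch, onMatch, nil)
	if err != nil {
		panic(err)
	}
	if err := stream.Scan([]byte("xxabbb")); err != nil {
		panic(err)
	}
	if err := stream.Close(); err != nil {
		panic(err)
	}
	fmt.Println(matches) // expect map[0:[[2 6]]]: one greedy slice per start offset
}
```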
diff --git a/stream_handler_test.go b/stream_handler_test.go
new file mode 100644
index 0000000..a1004dc
--- /dev/null
+++ b/stream_handler_test.go
@@ -0,0 +1,193 @@
+package main
+
+import (
+ "context"
+ "errors"
+ "github.com/flier/gohs/hyperscan"
+ "github.com/google/gopacket/layers"
+ "github.com/google/gopacket/tcpassembly"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "math/rand"
+ "net"
+ "testing"
+ "time"
+)
+
+const testSrcIp = "10.10.10.100"
+const testDstIp = "10.10.10.1"
+const srcPort = 44444
+const dstPort = 8080
+
+
+func TestReassemblingEmptyStream(t *testing.T) {
+ patterns, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/nope/", 0))
+ require.Nil(t, err)
+ streamHandler := createTestStreamHandler(t, testStorage{}, patterns)
+
+ streamHandler.Reassembled([]tcpassembly.Reassembly{})
+ assert.Len(t, streamHandler.indexes, 0, "indexes")
+ assert.Len(t, streamHandler.timestamps, 0, "timestamps")
+ assert.Len(t, streamHandler.lossBlocks, 0)
+ assert.Zero(t, streamHandler.currentIndex)
+ assert.Zero(t, streamHandler.firstPacketSeen)
+ assert.Zero(t, streamHandler.lastPacketSeen)
+ assert.Len(t, streamHandler.documentsKeys, 0)
+ assert.Zero(t, streamHandler.streamLength)
+ assert.Len(t, streamHandler.patternMatches, 0)
+
+ expected := 0
+ streamHandler.connection.(*testConnectionHandler).onComplete = func(handler *StreamHandler) {
+ expected = 42
+ }
+ streamHandler.ReassemblyComplete()
+ assert.Equal(t, 42, expected)
+}
+
+func TestReassemblingSingleDocumentStream(t *testing.T) {
+ patterns, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/impossible_to_match/", 0))
+ require.Nil(t, err)
+ storage := &testStorage{}
+ streamHandler := createTestStreamHandler(t, storage, patterns)
+
+ payloadLen := 256
+ firstTime := time.Unix(1000000000, 0)
+ middleTime := time.Unix(1000000010, 0)
+ lastTime := time.Unix(1000000020, 0)
+ data := make([]byte, MaxDocumentSize)
+ rand.Read(data)
+ reassembles := make([]tcpassembly.Reassembly, MaxDocumentSize / payloadLen)
+ indexes := make([]int, MaxDocumentSize / payloadLen)
+ timestamps := make([]time.Time, MaxDocumentSize / payloadLen)
+ lossBlocks := make([]bool, MaxDocumentSize / payloadLen)
+ for i := 0; i < len(reassembles); i++ {
+ var seen time.Time
+ if i == 0 {
+ seen = firstTime
+ } else if i == len(reassembles)-1 {
+ seen = lastTime
+ } else {
+ seen = middleTime
+ }
+
+ reassembles[i] = tcpassembly.Reassembly{
+ Bytes: data[i*payloadLen:(i+1)*payloadLen],
+ Skip: 0,
+ Start: i == 0,
+ End: i == len(reassembles)-1,
+ Seen: seen,
+ }
+ indexes[i] = i*payloadLen
+ timestamps[i] = seen
+ }
+
+ inserted := false
+ storage.insertFunc = func(ctx context.Context, collectionName string, document interface{}) (i interface{}, err error) {
+ od := document.(OrderedDocument)
+ assert.Equal(t, "connection_streams", collectionName)
+ assert.Equal(t, "bb41a60281cfae830000b6b3a7640000", od[0].Value)
+ assert.Equal(t, nil, od[1].Value)
+ assert.Equal(t, 0, od[2].Value)
+ assert.Equal(t, data, od[3].Value)
+ assert.Equal(t, indexes, od[4].Value)
+ assert.Equal(t, timestamps, od[5].Value)
+ assert.Equal(t, lossBlocks, od[6].Value)
+ assert.Len(t, od[7].Value, 0)
+ inserted = true
+ return nil, nil
+ }
+
+ streamHandler.Reassembled(reassembles)
+ if !assert.Equal(t, false, inserted) {
+ inserted = false
+ }
+
+ assert.Equal(t, data, streamHandler.buffer.Bytes(), "buffer should contain the same bytes as reassembles")
+ assert.Equal(t, indexes, streamHandler.indexes, "indexes")
+ assert.Equal(t, timestamps, streamHandler.timestamps, "timestamps")
+ assert.Equal(t, lossBlocks, streamHandler.lossBlocks, "lossBlocks")
+ assert.Equal(t, len(data), streamHandler.currentIndex)
+ assert.Equal(t, firstTime, streamHandler.firstPacketSeen)
+ assert.Equal(t, lastTime, streamHandler.lastPacketSeen)
+ assert.Len(t, streamHandler.documentsKeys, 0)
+ assert.Equal(t, len(data), streamHandler.streamLength)
+ assert.Len(t, streamHandler.patternMatches, 0)
+
+ completed := false
+ streamHandler.connection.(*testConnectionHandler).onComplete = func(handler *StreamHandler) {
+ completed = true
+ }
+ streamHandler.ReassemblyComplete()
+ assert.Equal(t, true, inserted, "inserted")
+ assert.Equal(t, true, completed, "completed")
+}
+
+
+func createTestStreamHandler(t *testing.T, storage Storage, patterns hyperscan.StreamDatabase) StreamHandler {
+ testConnectionHandler := &testConnectionHandler{
+ storage: storage,
+ context: context.Background(),
+ patterns: patterns,
+ }
+
+ scratch, err := hyperscan.NewScratch(patterns)
+ require.Nil(t, err)
+
+ srcIp := layers.NewIPEndpoint(net.ParseIP(testSrcIp))
+ dstIp := layers.NewIPEndpoint(net.ParseIP(testDstIp))
+ srcPort := layers.NewTCPPortEndpoint(srcPort)
+ dstPort := layers.NewTCPPortEndpoint(dstPort)
+
+ return NewStreamHandler(testConnectionHandler, StreamKey{srcIp, dstIp, srcPort, dstPort}, scratch)
+}
+
+type testConnectionHandler struct {
+ storage Storage
+ context context.Context
+ patterns hyperscan.StreamDatabase
+ onComplete func(*StreamHandler)
+}
+
+func (tch *testConnectionHandler) Storage() Storage {
+ return tch.storage
+}
+
+func (tch *testConnectionHandler) Context() context.Context {
+ return tch.context
+}
+
+func (tch *testConnectionHandler) Patterns() hyperscan.StreamDatabase {
+ return tch.patterns
+}
+
+func (tch *testConnectionHandler) Complete(handler *StreamHandler) {
+ tch.onComplete(handler)
+}
+
+type testStorage struct {
+ insertFunc func(ctx context.Context, collectionName string, document interface{}) (interface{}, error)
+ updateOne func(ctx context.Context, collectionName string, filter interface{}, update interface{}, upsert bool) (interface{}, error)
+ updateMany func(ctx context.Context, collectionName string, filter interface{}, update interface{}, upsert bool) (interface{}, error)
+ findOne func(ctx context.Context, collectionName string, filter interface{}) (UnorderedDocument, error)
+}
+
+func (ts testStorage) InsertOne(ctx context.Context, collectionName string, document interface{}) (interface{}, error) {
+ if ts.insertFunc != nil {
+ return ts.insertFunc(ctx, collectionName, document)
+ }
+ return nil, errors.New("not implemented")
+}
+
+func (ts testStorage) UpdateOne(ctx context.Context, collectionName string, filter interface{}, update interface{},
+ upsert bool) (interface{}, error) {
+ if ts.updateOne != nil {
+ return ts.updateOne(ctx, collectionName, filter, update, upsert)
+ }
+ return nil, errors.New("not implemented")
+}
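+
+// UpdateMany mirrors UpdateOne so that testStorage still satisfies the Storage interface.
+func (ts testStorage) UpdateMany(ctx context.Context, collectionName string, filter interface{}, update interface{},
+ upsert bool) (interface{}, error) {
+ if ts.updateMany != nil {
+ return ts.updateMany(ctx, collectionName, filter, update, upsert)
+ }
+ return nil, errors.New("not implemented")
+}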
+
+func (ts testStorage) FindOne(ctx context.Context, collectionName string, filter interface{}) (UnorderedDocument, error) {
+ if ts.findOne != nil {
+ return ts.findOne(ctx, collectionName, filter)
+ }
+ return nil, errors.New("not implemented")
+}