 caronte_test.go        |  60
 connection_handler.go  |  14
 connection_streams.go  |  16
 go.mod                 |   1
 go.sum                 |   1
 pcap_importer.go       |  22
 routes.go              |   3
 rules_manager.go       | 132
 storage.go             | 339
 storage_test.go        | 300
 stream_handler.go      |  71
 stream_handler_test.go | 219
 12 files changed, 628 insertions(+), 550 deletions(-)
diff --git a/caronte_test.go b/caronte_test.go
index 9942086..2766640 100644
--- a/caronte_test.go
+++ b/caronte_test.go
@@ -4,21 +4,21 @@ import (
"context"
"crypto/sha256"
"fmt"
- "go.mongodb.org/mongo-driver/mongo"
- "go.mongodb.org/mongo-driver/mongo/options"
- "log"
+ log "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/require"
"os"
+ "strconv"
"testing"
"time"
)
-var storage Storage
-var testContext context.Context
-
-const testInsertManyFindCollection = "testFi"
-const testCollection = "characters"
+type TestStorageWrapper struct {
+ DbName string
+ Storage *MongoStorage
+ Context context.Context
+}
-func TestMain(m *testing.M) {
+func NewTestStorageWrapper(t *testing.T) *TestStorageWrapper {
mongoHost, ok := os.LookupEnv("MONGO_HOST")
if !ok {
mongoHost = "localhost"
@@ -27,33 +27,31 @@ func TestMain(m *testing.M) {
if !ok {
mongoPort = "27017"
}
+ port, err := strconv.Atoi(mongoPort)
+ require.NoError(t, err, "invalid port")
uniqueDatabaseName := sha256.Sum256([]byte(time.Now().String()))
-
- client, err := mongo.NewClient(options.Client().ApplyURI(fmt.Sprintf("mongodb://%s:%v", mongoHost, mongoPort)))
- if err != nil {
- panic("failed to create mongo client")
- }
-
dbName := fmt.Sprintf("%x", uniqueDatabaseName[:31])
- db := client.Database(dbName)
- log.Println("using database", dbName)
- mongoStorage := MongoStorage{
- client: client,
- collections: map[string]*mongo.Collection{
- testInsertManyFindCollection: db.Collection(testInsertManyFindCollection),
- testCollection: db.Collection(testCollection),
- },
- }
+ log.WithField("dbName", dbName).Info("creating new storage")
+
+ storage := NewMongoStorage(mongoHost, port, dbName)
+ ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
- testContext, _ = context.WithTimeout(context.Background(), 10 * time.Second)
+ err = storage.Connect(ctx)
+ require.NoError(t, err, "failed to connect to database")
- err = mongoStorage.Connect(testContext)
- if err != nil {
- panic(err)
+ return &TestStorageWrapper{
+ DbName: dbName,
+ Storage: storage,
+ Context: ctx,
}
- storage = &mongoStorage
+}
+
+func (tsw TestStorageWrapper) AddCollection(collectionName string) {
+ tsw.Storage.collections[collectionName] = tsw.Storage.client.Database(tsw.DbName).Collection(collectionName)
+}
- exitCode := m.Run()
- os.Exit(exitCode)
+func (tsw TestStorageWrapper) Destroy(t *testing.T) {
+ err := tsw.Storage.client.Disconnect(tsw.Context)
+ require.NoError(t, err, "failed to disconnect from database")
}
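
A minimal sketch of how a test in this package can use the new wrapper (the test and collection names here are invented for illustration):

    func TestWrapperUsageExample(t *testing.T) {
        // create an isolated per-test database, connect, register a collection
        wrapper := NewTestStorageWrapper(t)
        wrapper.AddCollection("examples")

        // exercise the fluent insert API against the per-test database
        id, err := wrapper.Storage.Insert("examples").
            Context(wrapper.Context).
            One(UnorderedDocument{"key": "value"})
        require.NoError(t, err)
        require.NotNil(t, id)

        // disconnect when the test is done
        wrapper.Destroy(t)
    }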
diff --git a/connection_handler.go b/connection_handler.go
index 36dca6d..2e2fa84 100644
--- a/connection_handler.go
+++ b/connection_handler.go
@@ -64,7 +64,7 @@ func (factory *BiDirectionalStreamFactory) New(net, transport gopacket.Flow) tcp
initiator: initiator,
mComplete: sync.Mutex{},
context: context.Background(),
- patterns : factory.patterns,
+ patterns: factory.patterns,
}
factory.connections[key] = connection
}
@@ -122,7 +122,7 @@ func (ch *connectionHandlerImpl) Complete(handler *StreamHandler) {
ch.generateConnectionKey(startedAt)
- _, err := ch.storage.InsertOne(ch.context, "connections", OrderedDocument{
+ _, err := ch.storage.Insert("connections").Context(ch.context).One(OrderedDocument{
{"_id", ch.connectionKey},
{"ip_src", ch.initiator[0].String()},
{"ip_dst", ch.initiator[1].String()},
@@ -141,14 +141,14 @@ func (ch *connectionHandlerImpl) Complete(handler *StreamHandler) {
}
streamsIds := append(client.documentsKeys, server.documentsKeys...)
- n, err := ch.storage.UpdateOne(ch.context, "connection_streams",
- UnorderedDocument{"_id": UnorderedDocument{"$in": streamsIds}},
- UnorderedDocument{"connection_id": ch.connectionKey},
- false)
+ n, err := ch.storage.Update("connection_streams").
+ Context(ch.context).
+ Filter(OrderedDocument{{"_id", UnorderedDocument{"$in": streamsIds}}}).
+ Many(UnorderedDocument{"connection_id": ch.connectionKey})
if err != nil {
log.Println("failed to update connection streams", err)
}
- if n != len(streamsIds) {
+ if int(n) != len(streamsIds) {
log.Println("failed to update all connections streams")
}
}
diff --git a/connection_streams.go b/connection_streams.go
new file mode 100644
index 0000000..bede526
--- /dev/null
+++ b/connection_streams.go
@@ -0,0 +1,16 @@
+package main
+
+import "time"
+
+type ConnectionStream struct {
+ ID RowID `json:"id" bson:"_id"`
+ ConnectionID RowID `json:"connection_id" bson:"connection_id"`
+ DocumentIndex int `json:"document_index" bson:"document_index"`
+ Payload []byte `json:"payload" bson:"payload"`
+ BlocksIndexes []int `json:"blocks_indexes" bson:"blocks_indexes"`
+ BlocksTimestamps []time.Time `json:"blocks_timestamps" bson:"blocks_timestamps"`
+ BlocksLoss []bool `json:"blocks_loss" bson:"blocks_loss"`
+ PatternMatches map[uint][]PatternSlice `json:"pattern_matches" bson:"pattern_matches"`
+}
+
+type PatternSlice [2]uint64
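
For illustration, a document built from this struct might look like the following sketch (payload and match offsets invented; it assumes a Storage instance named storage):

    func exampleStream(storage Storage) ConnectionStream {
        return ConnectionStream{
            ID:               storage.NewRowID(),
            ConnectionID:     ZeroRowID, // patched later by the connection handler
            DocumentIndex:    0,
            Payload:          []byte("GET / HTTP/1.1\r\n"),
            BlocksIndexes:    []int{0},
            BlocksTimestamps: []time.Time{time.Now()},
            BlocksLoss:       []bool{false},
            PatternMatches: map[uint][]PatternSlice{
                0: {{0, 3}}, // pattern 0 matched payload bytes [0, 3)
            },
        }
    }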
diff --git a/go.mod b/go.mod
index 99863d7..ce12c15 100644
--- a/go.mod
+++ b/go.mod
@@ -10,6 +10,7 @@ require (
github.com/google/gopacket v1.1.17
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.1 // indirect
+ github.com/sirupsen/logrus v1.4.2
github.com/stretchr/testify v1.4.0
go.mongodb.org/mongo-driver v1.3.1
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3
diff --git a/go.sum b/go.sum
index 16a4ce9..d2863ce 100644
--- a/go.sum
+++ b/go.sum
@@ -103,6 +103,7 @@ github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/sirupsen/logrus v1.4.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
+github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/smartystreets/assertions v0.0.0-20180820201707-7c9eb446e3cf h1:6V1qxN6Usn4jy8unvggSJz/NC790tefw8Zdy6OZS5co=
github.com/smartystreets/assertions v0.0.0-20180820201707-7c9eb446e3cf/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
diff --git a/pcap_importer.go b/pcap_importer.go
index c6260de..ac7c0b5 100644
--- a/pcap_importer.go
+++ b/pcap_importer.go
@@ -22,7 +22,6 @@ const importUpdateProgressInterval = 3 * time.Second
const initialPacketPerServicesMapSize = 16
const importedPcapsCollectionName = "imported_pcaps"
-
type PcapImporter struct {
storage Storage
streamPool *tcpassembly.StreamPool
@@ -35,11 +34,10 @@ type PcapImporter struct {
type flowCount [2]int
-
func NewPcapImporter(storage Storage, serverIp net.IP) *PcapImporter {
serverEndpoint := layers.NewIPEndpoint(serverIp)
streamFactory := &BiDirectionalStreamFactory{
- storage: storage,
+ storage: storage,
serverIp: serverEndpoint,
}
streamPool := tcpassembly.NewStreamPool(streamFactory)
@@ -82,7 +80,7 @@ func (pi *PcapImporter) ImportPcap(fileName string) (string, error) {
{"importing_error", err},
}
ctx, canc := context.WithCancel(context.Background())
- _, err = pi.storage.InsertOne(ctx, importedPcapsCollectionName, doc)
+ _, err = pi.storage.Insert(importedPcapsCollectionName).Context(ctx).One(doc)
if err != nil {
pi.mSessions.Unlock()
_, alreadyProcessed := err.(mongo.WriteException)
@@ -133,18 +131,18 @@ func (pi *PcapImporter) parsePcap(sessionId, fileName string, ctx context.Contex
progressUpdate := func(completed bool, err error) {
update := UnorderedDocument{
- "processed_packets": processedPackets,
- "invalid_packets": invalidPackets,
+ "processed_packets": processedPackets,
+ "invalid_packets": invalidPackets,
"packets_per_services": packetsPerService,
- "importing_error": err,
+ "importing_error": err,
}
if completed {
update["completed_at"] = time.Now()
}
- _, _err := pi.storage.UpdateOne(nil, importedPcapsCollectionName, OrderedDocument{{"_id", sessionId}},
- completed, false)
-
+ _, _err := pi.storage.Update(importedPcapsCollectionName).
+ Filter(OrderedDocument{{"_id", sessionId}}).
+ One(update)
if _err != nil {
log.Println("can't update importing statistics : ", _err)
}
@@ -158,10 +156,10 @@ func (pi *PcapImporter) parsePcap(sessionId, fileName string, ctx context.Contex
for {
select {
- case <- ctx.Done():
+ case <-ctx.Done():
handle.Close()
deleteSession()
- progressUpdate(false, errors.New("import process cancelled"))
+ progressUpdate(false, errors.New("import process cancelled"))
return
default:
}
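
The loop above polls for cancellation without blocking; a self-contained sketch of the same select/default idiom (function and callback names assumed):

    // processUntilCancelled drains work items until ctx is cancelled or
    // next reports that nothing is left to read.
    func processUntilCancelled(ctx context.Context, next func() bool) {
        for {
            select {
            case <-ctx.Done():
                return // cancelled: the real code closes the handle and deletes the session here
            default:
            }
            if !next() {
                return // no more packets to read
            }
        }
    }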
diff --git a/routes.go b/routes.go
index f44cff7..3759382 100644
--- a/routes.go
+++ b/routes.go
@@ -4,7 +4,7 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/go-playground/validator/v10"
- "log"
+ log "github.com/sirupsen/logrus"
"net/http"
)
@@ -23,6 +23,7 @@ func ApplicationRoutes(engine *gin.Engine) {
"error": fmt.Sprintf("field '%v' does not respect the %v(%v) rule",
fieldErr.Field(), fieldErr.Tag(), fieldErr.Param()),
})
+ log.WithError(err).WithField("rule", rule).Panic("oops")
return // exit on first error
}
}
diff --git a/rules_manager.go b/rules_manager.go
index d6e9aaa..e5a8d38 100644
--- a/rules_manager.go
+++ b/rules_manager.go
@@ -3,56 +3,53 @@ package main
import (
"context"
"crypto/sha256"
- "encoding/json"
"errors"
"fmt"
"github.com/flier/gohs/hyperscan"
- "go.mongodb.org/mongo-driver/bson/primitive"
- "log"
+ log "github.com/sirupsen/logrus"
"sync"
"time"
)
-
type RegexFlags struct {
- Caseless bool `json:"caseless"` // Set case-insensitive matching.
- DotAll bool `json:"dot_all"` // Matching a `.` will not exclude newlines.
- MultiLine bool `json:"multi_line"` // Set multi-line anchoring.
- SingleMatch bool `json:"single_match"` // Set single-match only mode.
- Utf8Mode bool `json:"utf_8_mode"` // Enable UTF-8 mode for this expression.
- UnicodeProperty bool `json:"unicode_property"` // Enable Unicode property support for this expression
+ Caseless bool `json:"caseless"` // Set case-insensitive matching.
+ DotAll bool `json:"dot_all"` // Matching a `.` will not exclude newlines.
+ MultiLine bool `json:"multi_line"` // Set multi-line anchoring.
+ SingleMatch bool `json:"single_match"` // Set single-match only mode.
+ Utf8Mode bool `json:"utf_8_mode"` // Enable UTF-8 mode for this expression.
+ UnicodeProperty bool `json:"unicode_property"` // Enable Unicode property support for this expression
}
type Pattern struct {
- Regex string `json:"regex"`
- Flags RegexFlags `json:"flags"`
- MinOccurrences int `json:"min_occurrences"`
- MaxOccurrences int `json:"max_occurrences"`
- internalId int
+ Regex string `json:"regex"`
+ Flags RegexFlags `json:"flags"`
+ MinOccurrences int `json:"min_occurrences"`
+ MaxOccurrences int `json:"max_occurrences"`
+ internalId int
compiledPattern *hyperscan.Pattern
}
type Filter struct {
- ServicePort int
+ ServicePort int
ClientAddress string
- ClientPort int
- MinDuration int
- MaxDuration int
- MinPackets int
- MaxPackets int
- MinSize int
- MaxSize int
+ ClientPort int
+ MinDuration int
+ MaxDuration int
+ MinPackets int
+ MaxPackets int
+ MinSize int
+ MaxSize int
}
type Rule struct {
- Id string `json:"-" bson:"_id,omitempty"`
- Name string `json:"name" binding:"required,min=3" bson:"name"`
- Color string `json:"color" binding:"required,hexcolor" bson:"color"`
- Notes string `json:"notes" bson:"notes,omitempty"`
- Enabled bool `json:"enabled" bson:"enabled"`
+ Id RowID `json:"-" bson:"_id,omitempty"`
+ Name string `json:"name" binding:"required,min=3" bson:"name"`
+ Color string `json:"color" binding:"required,hexcolor" bson:"color"`
+ Notes string `json:"notes" bson:"notes,omitempty"`
+ Enabled bool `json:"enabled" bson:"enabled"`
Patterns []Pattern `json:"patterns" binding:"required,min=1" bson:"patterns"`
- Filter Filter `json:"filter" bson:"filter,omitempty"`
- Version int64 `json:"version" bson:"version"`
+ Filter Filter `json:"filter" bson:"filter,omitempty"`
+ Version int64 `json:"version" bson:"version"`
}
type RulesManager struct {
@@ -67,57 +64,53 @@ type RulesManager struct {
func NewRulesManager(storage Storage) RulesManager {
return RulesManager{
- storage: storage,
- rules: make(map[string]Rule),
- patterns: make(map[string]Pattern),
- mPatterns: sync.Mutex{},
+ storage: storage,
+ rules: make(map[string]Rule),
+ patterns: make(map[string]Pattern),
+ mPatterns: sync.Mutex{},
}
}
-
-
func (rm RulesManager) LoadRules() error {
var rules []Rule
- if err := rm.storage.Find(nil, Rules, NoFilters, &rules); err != nil {
+ if err := rm.storage.Find(Rules).Sort("_id", true).All(&rules); err != nil {
return err
}
- var version int64
for _, rule := range rules {
if err := rm.validateAndAddRuleLocal(&rule); err != nil {
- log.Printf("failed to import rule %s: %s\n", rule.Name, err)
- continue
- }
- if rule.Version > version {
- version = rule.Version
+ log.WithError(err).WithField("rule", rule).Warn("failed to import rule")
}
}
rm.ruleIndex = len(rules)
- return rm.generateDatabase(0)
+ if len(rules) == 0 {
+ return rm.generateDatabase(ZeroRowID)
+ }
+ return rm.generateDatabase(rules[len(rules)-1].Id)
}
func (rm RulesManager) AddRule(context context.Context, rule Rule) (string, error) {
rm.mPatterns.Lock()
- rule.Id = UniqueKey(time.Now(), uint32(rm.ruleIndex))
+ rule.Id = rm.storage.NewCustomRowID(uint64(rm.ruleIndex), time.Now())
rule.Enabled = true
if err := rm.validateAndAddRuleLocal(&rule); err != nil {
rm.mPatterns.Unlock()
return "", err
}
+
+ if err := rm.generateDatabase(rule.Id); err != nil {
+ rm.mPatterns.Unlock()
+ log.WithError(err).WithField("rule", rule).Panic("failed to generate database")
+ }
rm.mPatterns.Unlock()
- if _, err := rm.storage.InsertOne(context, Rules, rule); err != nil {
- return "", err
+ if _, err := rm.storage.Insert(Rules).Context(context).One(rule); err != nil {
+ log.WithError(err).WithField("rule", rule).Panic("failed to insert rule on database")
}
- return rule.Id, rm.generateDatabase(rule.Id)
+ return rule.Id.Hex(), nil
}
-
-
func (rm RulesManager) validateAndAddRuleLocal(rule *Rule) error {
if _, alreadyPresent := rm.rulesByName[rule.Name]; alreadyPresent {
return errors.New("rule name must be unique")
@@ -142,16 +135,18 @@ func (rm RulesManager) validateAndAddRuleLocal(rule *Rule) error {
rm.patterns[key] = value
}
- rm.rules[rule.Id] = *rule
+ rm.rules[rule.Id.Hex()] = *rule
rm.rulesByName[rule.Name] = *rule
return nil
}
-func (rm RulesManager) generateDatabase(version string) error {
+func (rm RulesManager) generateDatabase(version RowID) error {
patterns := make([]*hyperscan.Pattern, len(rm.patterns))
+ var i int
for _, pattern := range rm.patterns {
- patterns = append(patterns, pattern.compiledPattern)
+ patterns[i] = pattern.compiledPattern
+ i++
}
database, err := hyperscan.NewStreamDatabase(patterns...)
if err != nil {
@@ -162,7 +157,6 @@ func (rm RulesManager) generateDatabase(version string) error {
return nil
}
-
func (p Pattern) BuildPattern() error {
if p.compiledPattern != nil {
return nil
@@ -210,33 +204,3 @@ func (p Pattern) Hash() string {
hash.Write([]byte(fmt.Sprintf("%s|%v|%v|%v", p.Regex, p.Flags, p.MinOccurrences, p.MaxOccurrences)))
return fmt.Sprintf("%x", hash.Sum(nil))
}
-
-func test() {
- user := &Pattern{Regex: "Frank"}
- b, err := json.Marshal(user)
- if err != nil {
- fmt.Println(err)
- return
- }
- fmt.Println(string(b))
-
- p, _ := hyperscan.ParsePattern("/a/")
- p1, _ := hyperscan.ParsePattern("/a/")
- fmt.Println(p1.String(), p1.Flags)
- //p1.Id = 1
-
- fmt.Println(*p == *p1)
- db, _ := hyperscan.NewBlockDatabase(p, p1)
- s, _ := hyperscan.NewScratch(db)
- db.Scan([]byte("Ciao"), s, onMatch, nil)
-
-
-
-
-}
-
-func onMatch(id uint, from uint64, to uint64, flags uint, context interface{}) error {
- fmt.Println(id)
-
- return nil
-}
\ No newline at end of file
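
The deleted test() scratch function above also documents the gohs calls RulesManager relies on; a cleaner, self-contained sketch of that flow (block-mode for brevity, pattern and input invented):

    func scanExample() error {
        p, err := hyperscan.ParsePattern("/ab+c/i")
        if err != nil {
            return err
        }
        p.Id = 1
        db, err := hyperscan.NewBlockDatabase(p)
        if err != nil {
            return err
        }
        scratch, err := hyperscan.NewScratch(db)
        if err != nil {
            return err
        }
        // same handler signature used by StreamHandler.onMatch
        onMatch := func(id uint, from, to uint64, flags uint, _ interface{}) error {
            fmt.Printf("pattern %d matched at [%d, %d)\n", id, from, to)
            return nil
        }
        return db.Scan([]byte("xabbbc"), scratch, onMatch, nil)
    }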
diff --git a/storage.go b/storage.go
index 3be56d7..b88c5c8 100644
--- a/storage.go
+++ b/storage.go
@@ -6,60 +6,55 @@ import (
"encoding/hex"
"errors"
"fmt"
- "time"
-
+ log "github.com/sirupsen/logrus"
"go.mongodb.org/mongo-driver/bson"
+ "go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
+ "time"
)
+// Collections names
const Connections = "connections"
+const ConnectionStreams = "connection_streams"
const ImportedPcaps = "imported_pcaps"
const Rules = "rules"
-var NoFilters = UnorderedDocument{}
+const defaultConnectionTimeout = 10 * time.Second
-const defaultConnectionTimeout = 10*time.Second
-const defaultOperationTimeout = 3*time.Second
+var ZeroRowID [12]byte
type Storage interface {
- InsertOne(ctx context.Context, collectionName string, document interface{}) (interface{}, error)
- InsertMany(ctx context.Context, collectionName string, documents []interface{}) ([]interface{}, error)
- UpdateOne(ctx context.Context, collectionName string, filter interface{}, update interface {}, upsert bool) (interface{}, error)
- UpdateMany(ctx context.Context, collectionName string, filter interface{}, update interface {}, upsert bool) (interface{}, error)
- FindOne(ctx context.Context, collectionName string, filter interface{}) (UnorderedDocument, error)
- Find(ctx context.Context, collectionName string, filter interface{}, results interface{}) error
+ Insert(collectionName string) InsertOperation
+ Update(collectionName string) UpdateOperation
+ Find(collectionName string) FindOperation
+ NewCustomRowID(payload uint64, timestamp time.Time) RowID
+ NewRowID() RowID
}
type MongoStorage struct {
- client *mongo.Client
+ client *mongo.Client
collections map[string]*mongo.Collection
}
type OrderedDocument = bson.D
type UnorderedDocument = bson.M
-
-func UniqueKey(timestamp time.Time, payload uint32) string {
- var key [8]byte
- binary.BigEndian.PutUint32(key[0:4], uint32(timestamp.Unix()))
- binary.BigEndian.PutUint32(key[4:8], payload)
-
- return hex.EncodeToString(key[:])
-}
+type Entry = bson.E
+type RowID = primitive.ObjectID
func NewMongoStorage(uri string, port int, database string) *MongoStorage {
opt := options.Client()
opt.ApplyURI(fmt.Sprintf("mongodb://%s:%v", uri, port))
client, err := mongo.NewClient(opt)
if err != nil {
- panic("Failed to create mongo client")
+ log.WithError(err).Panic("failed to create mongo client")
}
db := client.Database(database)
colls := map[string]*mongo.Collection{
- Connections: db.Collection(Connections),
+ Connections: db.Collection(Connections),
ImportedPcaps: db.Collection(ImportedPcaps),
- Rules: db.Collection(Rules),
+ Rules: db.Collection(Rules),
}
return &MongoStorage{
@@ -76,19 +71,55 @@ func (storage *MongoStorage) Connect(ctx context.Context) error {
return storage.client.Connect(ctx)
}
-func (storage *MongoStorage) InsertOne(ctx context.Context, collectionName string,
- document interface{}) (interface{}, error) {
+func (storage *MongoStorage) NewCustomRowID(payload uint64, timestamp time.Time) RowID {
+ var key [12]byte
+ binary.BigEndian.PutUint32(key[0:4], uint32(timestamp.Unix()))
+ binary.BigEndian.PutUint64(key[4:12], payload)
- collection, ok := storage.collections[collectionName]
- if !ok {
- return nil, errors.New("invalid collection: " + collectionName)
+ if oid, err := primitive.ObjectIDFromHex(hex.EncodeToString(key[:])); err == nil {
+ return oid
+ } else {
+ log.WithError(err).Warn("failed to create object id")
+ return primitive.NewObjectID()
}
+}
- if ctx == nil {
- ctx, _ = context.WithTimeout(context.Background(), defaultOperationTimeout)
+func (storage *MongoStorage) NewRowID() RowID {
+ return primitive.NewObjectID()
+}
+
+// InsertOne and InsertMany
+
+type InsertOperation interface {
+ Context(ctx context.Context) InsertOperation
+ StopOnFail(stop bool) InsertOperation
+ One(document interface{}) (interface{}, error)
+ Many(documents []interface{}) ([]interface{}, error)
+}
+
+type MongoInsertOperation struct {
+ collection *mongo.Collection
+ ctx context.Context
+ optInsertMany *options.InsertManyOptions
+ err error
+}
+
+func (fo MongoInsertOperation) Context(ctx context.Context) InsertOperation {
+ fo.ctx = ctx
+ return fo
+}
+
+func (fo MongoInsertOperation) StopOnFail(stop bool) InsertOperation {
+ fo.optInsertMany.SetOrdered(stop)
+ return fo
+}
+
+func (fo MongoInsertOperation) One(document interface{}) (interface{}, error) {
+ if fo.err != nil {
+ return nil, fo.err
}
- result, err := collection.InsertOne(ctx, document)
+ result, err := fo.collection.InsertOne(fo.ctx, document)
if err != nil {
return nil, err
}
@@ -96,149 +127,211 @@ func (storage *MongoStorage) InsertOne(ctx context.Context, collectionName strin
return result.InsertedID, nil
}
-func (storage *MongoStorage) InsertMany(ctx context.Context, collectionName string,
- documents []interface{}) ([]interface{}, error) {
-
- collection, ok := storage.collections[collectionName]
- if !ok {
- return nil, errors.New("invalid collection: " + collectionName)
- }
-
- if ctx == nil {
- ctx, _ = context.WithTimeout(context.Background(), defaultOperationTimeout)
+func (fo MongoInsertOperation) Many(documents []interface{}) ([]interface{}, error) {
+ if fo.err != nil {
+ return nil, fo.err
}
- result, err := collection.InsertMany(ctx, documents)
+ results, err := fo.collection.InsertMany(fo.ctx, documents, fo.optInsertMany)
if err != nil {
return nil, err
}
- return result.InsertedIDs, nil
+ return results.InsertedIDs, nil
}
-func (storage *MongoStorage) UpdateOne(ctx context.Context, collectionName string,
- filter interface{}, update interface {}, upsert bool) (interface{}, error) {
-
+func (storage *MongoStorage) Insert(collectionName string) InsertOperation {
collection, ok := storage.collections[collectionName]
+ op := MongoInsertOperation{
+ collection: collection,
+ optInsertMany: options.InsertMany(),
+ }
if !ok {
- return nil, errors.New("invalid collection: " + collectionName)
+ op.err = errors.New("invalid collection: " + collectionName)
}
+ return op
+}
- if ctx == nil {
- ctx, _ = context.WithTimeout(context.Background(), defaultOperationTimeout)
- }
+// UpdateOne and UpdateMany
- opts := options.Update().SetUpsert(upsert)
- update = bson.D{{"$set", update}}
+type UpdateOperation interface {
+ Context(ctx context.Context) UpdateOperation
+ Filter(filter OrderedDocument) UpdateOperation
+ Upsert(upsertResults *interface{}) UpdateOperation
+ One(update interface{}) (bool, error)
+ Many(update interface{}) (int64, error)
+}
- result, err := collection.UpdateOne(ctx, filter, update, opts)
- if err != nil {
- return nil, err
- }
+type MongoUpdateOperation struct {
+ collection *mongo.Collection
+ filter OrderedDocument
+ update OrderedDocument
+ ctx context.Context
+ opt *options.UpdateOptions
+ upsertResult *interface{}
+ err error
+}
- if upsert {
- return result.UpsertedID, nil
- }
+func (fo MongoUpdateOperation) Context(ctx context.Context) UpdateOperation {
+ fo.ctx = ctx
+ return fo
+}
- return result.ModifiedCount == 1, nil
+func (fo MongoUpdateOperation) Filter(filter OrderedDocument) UpdateOperation {
+ fo.filter = filter
+ return fo
}
-func (storage *MongoStorage) UpdateMany(ctx context.Context, collectionName string,
- filter interface{}, update interface {}, upsert bool) (interface{}, error) {
+func (fo MongoUpdateOperation) Upsert(upsertResults *interface{}) UpdateOperation {
+ fo.upsertResult = upsertResults
+ fo.opt.SetUpsert(true)
+ return fo
+}
- collection, ok := storage.collections[collectionName]
- if !ok {
- return nil, errors.New("invalid collection: " + collectionName)
+func (fo MongoUpdateOperation) One(update interface{}) (bool, error) {
+ if fo.err != nil {
+ return false, fo.err
}
- if ctx == nil {
- ctx, _ = context.WithTimeout(context.Background(), defaultOperationTimeout)
+ for i := range fo.update {
+ fo.update[i].Value = update
+ }
+ result, err := fo.collection.UpdateOne(fo.ctx, fo.filter, fo.update, fo.opt)
+ if err != nil {
+ return false, err
}
- opts := options.Update().SetUpsert(upsert)
- update = bson.D{{"$set", update}}
+ if fo.upsertResult != nil {
+ *(fo.upsertResult) = result.UpsertedID
+ }
+ return result.ModifiedCount == 1, nil
+}
+
+func (fo MongoUpdateOperation) Many(update interface{}) (int64, error) {
+ if fo.err != nil {
+ return 0, fo.err
+ }
- result, err := collection.UpdateMany(ctx, filter, update, opts)
+ for i := range fo.update {
+ fo.update[i].Value = update
+ }
+ result, err := fo.collection.UpdateMany(fo.ctx, fo.filter, fo.update, fo.opt)
if err != nil {
- return nil, err
+ return 0, err
}
+ if fo.upsertResult != nil {
+ *(fo.upsertResult) = result.UpsertedID
+ }
return result.ModifiedCount, nil
}
-func (storage *MongoStorage) FindOne(ctx context.Context, collectionName string,
- filter interface{}) (UnorderedDocument, error) {
-
+func (storage *MongoStorage) Update(collectionName string) UpdateOperation {
collection, ok := storage.collections[collectionName]
+ op := MongoUpdateOperation{
+ collection: collection,
+ filter: OrderedDocument{},
+ update: OrderedDocument{{"$set", nil}},
+ opt: options.Update(),
+ }
if !ok {
- return nil, errors.New("invalid collection: " + collectionName)
+ op.err = errors.New("invalid collection: " + collectionName)
}
+ return op
+}
- if ctx == nil {
- ctx, _ = context.WithTimeout(context.Background(), defaultOperationTimeout)
- }
+// Find and FindOne
- var result bson.M
- err := collection.FindOne(ctx, filter).Decode(&result)
- if err != nil {
- if err == mongo.ErrNoDocuments {
- return nil, nil
- }
+type FindOperation interface {
+ Context(ctx context.Context) FindOperation
+ Filter(filter OrderedDocument) FindOperation
+ Sort(field string, ascending bool) FindOperation
+ Limit(n int64) FindOperation
+ First(result interface{}) error
+ All(results interface{}) error
+}
- return nil, err
- }
+type MongoFindOperation struct {
+ collection *mongo.Collection
+ filter OrderedDocument
+ ctx context.Context
+ optFind *options.FindOptions
+ optFindOne *options.FindOneOptions
+ sorts []Entry
+ err error
+}
- return result, nil
+func (fo MongoFindOperation) Context(ctx context.Context) FindOperation {
+ fo.ctx = ctx
+ return fo
}
-type FindOperation struct {
- options options.FindOptions
+func (fo MongoFindOperation) Filter(filter OrderedDocument) FindOperation {
+ fo.filter = filter
+ return fo
}
+func (fo MongoFindOperation) Limit(n int64) FindOperation {
+ fo.optFind.SetLimit(n)
+ return fo
+}
+func (fo MongoFindOperation) Sort(field string, ascending bool) FindOperation {
+ var sort int
+ if ascending {
+ sort = 1
+ } else {
+ sort = -1
+ }
+ fo.sorts = append(fo.sorts, primitive.E{Key: field, Value: sort})
+ fo.optFind.SetSort(fo.sorts)
+ fo.optFindOne.SetSort(fo.sorts)
+ return fo
+}
-func (storage *MongoStorage) Find(ctx context.Context, collectionName string,
- filter interface{}, results interface{}) error {
+func (fo MongoFindOperation) First(result interface{}) error {
+ if fo.err != nil {
+ return fo.err
+ }
- collection, ok := storage.collections[collectionName]
- if !ok {
- return errors.New("invalid collection: " + collectionName)
+ err := fo.collection.FindOne(fo.ctx, fo.filter, fo.optFindOne).Decode(result)
+ if err != nil {
+ if err == mongo.ErrNoDocuments {
+ result = nil
+ return nil
+ }
+
+ return err
}
+ return nil
+}
- if ctx == nil {
- ctx, _ = context.WithTimeout(context.Background(), defaultOperationTimeout)
- }
-
- options.FindOptions{
- AllowDiskUse: nil,
- AllowPartialResults: nil,
- BatchSize: nil,
- Collation: nil,
- Comment: nil,
- CursorType: nil,
- Hint: nil,
- Limit: nil,
- Max: nil,
- MaxAwaitTime: nil,
- MaxTime: nil,
- Min: nil,
- NoCursorTimeout: nil,
- OplogReplay: nil,
- Projection: nil,
- ReturnKey: nil,
- ShowRecordID: nil,
- Skip: nil,
- Snapshot: nil,
- Sort: nil,
- }
- cursor, err := collection.Find(ctx, filter)
+func (fo MongoFindOperation) All(results interface{}) error {
+ if fo.err != nil {
+ return fo.err
+ }
+ cursor, err := fo.collection.Find(fo.ctx, fo.filter, fo.optFind)
if err != nil {
return err
}
- err = cursor.All(ctx, results)
+ err = cursor.All(fo.ctx, results)
if err != nil {
return err
}
-
return nil
}
+
+func (storage *MongoStorage) Find(collectionName string) FindOperation {
+ collection, ok := storage.collections[collectionName]
+ op := MongoFindOperation{
+ collection: collection,
+ filter: OrderedDocument{},
+ optFind: options.Find(),
+ optFindOne: options.FindOne(),
+ sorts: OrderedDocument{},
+ }
+ if !ok {
+ op.err = errors.New("invalid collection: " + collectionName)
+ }
+ return op
+}
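
Taken together, the three builders replace the old positional-argument methods; a hedged sketch of typical call sites (assuming a connected storage, a ctx, and the package's Rule type):

    func exampleCalls(storage Storage, ctx context.Context, rule Rule) error {
        // insert one rule document; the generated _id comes back as interface{}
        id, err := storage.Insert(Rules).Context(ctx).One(rule)
        if err != nil {
            return err
        }

        // update that document: the builder wraps the argument in {"$set": ...}
        if _, err := storage.Update(Rules).
            Context(ctx).
            Filter(OrderedDocument{{"_id", id}}).
            One(UnorderedDocument{"enabled": false}); err != nil {
            return err
        }

        // read all rules back, newest first
        var rules []Rule
        return storage.Find(Rules).Context(ctx).Sort("_id", false).All(&rules)
    }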
diff --git a/storage_test.go b/storage_test.go
index 40440a4..e34bdb3 100644
--- a/storage_test.go
+++ b/storage_test.go
@@ -1,8 +1,6 @@
package main
import (
- "context"
- "errors"
"github.com/stretchr/testify/assert"
"go.mongodb.org/mongo-driver/bson/primitive"
"testing"
@@ -11,102 +9,177 @@ import (
type a struct {
Id primitive.ObjectID `bson:"_id,omitempty"`
- A string `bson:"a,omitempty"`
- B int `bson:"b,omitempty"`
- C time.Time `bson:"c,omitempty"`
- D map[string]b `bson:"d"`
- E []b `bson:"e,omitempty"`
+ A string `bson:"a,omitempty"`
+ B int `bson:"b,omitempty"`
+ C time.Time `bson:"c,omitempty"`
+ D map[string]b `bson:"d"`
+ E []b `bson:"e,omitempty"`
}
type b struct {
A string `bson:"a,omitempty"`
- B int `bson:"b,omitempty"`
+ B int `bson:"b,omitempty"`
}
-func testInsert(t *testing.T) {
- // insert a document in an invalid connection
- insertedId, err := storage.InsertOne(testContext, "invalid_collection",
- OrderedDocument{{"key", "invalid"}})
- if insertedId != nil || err == nil {
- t.Fatal("inserting documents in invalid collections must fail")
- }
+func TestOperationOnInvalidCollection(t *testing.T) {
+ wrapper := NewTestStorageWrapper(t)
- // insert ordered document
- beatriceId, err := storage.InsertOne(testContext, testCollection,
- OrderedDocument{{"name", "Beatrice"}, {"description", "blablabla"}})
- if err != nil {
- t.Fatal(err)
- }
- if beatriceId == nil {
- t.Fatal("failed to insert an ordered document")
- }
+ simpleDoc := UnorderedDocument{"key": "a", "value": 0}
+ insertOp := wrapper.Storage.Insert("invalid_collection").Context(wrapper.Context)
+ insertedId, err := insertOp.One(simpleDoc)
+ assert.Nil(t, insertedId)
+ assert.Error(t, err)
- // insert unordered document
- virgilioId, err := storage.InsertOne(testContext, testCollection,
- UnorderedDocument{"name": "Virgilio", "description": "blablabla"})
- if err != nil {
- t.Fatal(err)
- }
- if virgilioId == nil {
- t.Fatal("failed to insert an unordered document")
- }
+ insertedIds, err := insertOp.Many([]interface{}{simpleDoc})
+ assert.Nil(t, insertedIds)
+ assert.Error(t, err)
- // insert document with custom id
- danteId := "000000"
- insertedId, err = storage.InsertOne(testContext, testCollection,
- UnorderedDocument{"_id": danteId, "name": "Dante Alighieri", "description": "blablabla"})
- if err != nil {
- t.Fatal(err)
- }
- if insertedId != danteId {
- t.Fatal("returned id doesn't match")
- }
+ updateOp := wrapper.Storage.Update("invalid_collection").Context(wrapper.Context)
+ isUpdated, err := updateOp.One(simpleDoc)
+ assert.False(t, isUpdated)
+ assert.Error(t, err)
- // insert duplicate document
- insertedId, err = storage.InsertOne(testContext, testCollection,
- UnorderedDocument{"_id": danteId, "name": "Dante Alighieri", "description": "blablabla"})
- if insertedId != nil || err == nil {
- t.Fatal("inserting duplicate id must fail")
- }
+ updated, err := updateOp.Many(simpleDoc)
+ assert.Zero(t, updated)
+ assert.Error(t, err)
+
+ findOp := wrapper.Storage.Find("invalid_collection").Context(wrapper.Context)
+ var result interface{}
+ err = findOp.First(&result)
+ assert.Nil(t, result)
+ assert.Error(t, err)
+
+ var results interface{}
+ err = findOp.All(&results)
+ assert.Nil(t, results)
+ assert.Error(t, err)
+
+ wrapper.Destroy(t)
}
-func testFindOne(t *testing.T) {
- // find a document in an invalid connection
- result, err := storage.FindOne(testContext, "invalid_collection",
- OrderedDocument{{"key", "invalid"}})
- if result != nil || err == nil {
- t.Fatal("find a document in an invalid collections must fail")
- }
+func TestSimpleInsertAndFind(t *testing.T) {
+ wrapper := NewTestStorageWrapper(t)
+ collectionName := "simple_insert_find"
+ wrapper.AddCollection(collectionName)
- // find an existing document
- result, err = storage.FindOne(testContext, testCollection, OrderedDocument{{"_id", "000000"}})
- if err != nil {
- t.Fatal(err)
- }
- if result == nil {
- t.Fatal("FindOne cannot find the valid document")
- }
- name, ok := result["name"]
- if !ok || name != "Dante Alighieri" {
- t.Fatal("document retrieved with FindOne is invalid")
- }
+ insertOp := wrapper.Storage.Insert(collectionName).Context(wrapper.Context)
+ simpleDocA := UnorderedDocument{"key": "a"}
+ idA, err := insertOp.One(simpleDocA)
+ assert.Len(t, idA, 12)
+ assert.Nil(t, err)
- // find an existing document
- result, err = storage.FindOne(testContext, testCollection, OrderedDocument{{"_id", "invalid_id"}})
- if err != nil {
- t.Fatal(err)
- }
- if result != nil {
- t.Fatal("FindOne cannot find an invalid document")
+ simpleDocB := UnorderedDocument{"_id": "idb", "key": "b"}
+ idB, err := insertOp.One(simpleDocB)
+ assert.Equal(t, "idb", idB)
+ assert.Nil(t, err)
+
+ var result UnorderedDocument
+ findOp := wrapper.Storage.Find(collectionName).Context(wrapper.Context)
+ err = findOp.Filter(OrderedDocument{{"key", "a"}}).First(&result)
+ assert.Nil(t, err)
+ assert.Equal(t, idA, result["_id"])
+ assert.Equal(t, simpleDocA["key"], result["key"])
+
+ err = findOp.Filter(OrderedDocument{{"_id", idB}}).First(&result)
+ assert.Nil(t, err)
+ assert.Equal(t, idB, result["_id"])
+ assert.Equal(t, simpleDocB["key"], result["key"])
+
+ wrapper.Destroy(t)
+}
+
+func TestSimpleInsertManyAndFindMany(t *testing.T) {
+ wrapper := NewTestStorageWrapper(t)
+ collectionName := "simple_insert_many_find_many"
+ wrapper.AddCollection(collectionName)
+
+ insertOp := wrapper.Storage.Insert(collectionName).Context(wrapper.Context)
+ simpleDocs := []interface{}{
+ UnorderedDocument{"key": "a"},
+ UnorderedDocument{"_id": "idb", "key": "b"},
+ UnorderedDocument{"key": "c"},
}
+ ids, err := insertOp.Many(simpleDocs)
+ assert.Nil(t, err)
+ assert.Len(t, ids, 3)
+ assert.Equal(t, "idb", ids[1])
+
+ var results []UnorderedDocument
+ findOp := wrapper.Storage.Find(collectionName).Context(wrapper.Context)
+ err = findOp.Sort("key", false).All(&results) // test sort ascending
+ assert.Nil(t, err)
+ assert.Len(t, results, 3)
+ assert.Equal(t, "c", results[0]["key"])
+ assert.Equal(t, "b", results[1]["key"])
+ assert.Equal(t, "a", results[2]["key"])
+
+ err = findOp.Sort("key", true).All(&results) // test sort descending
+ assert.Nil(t, err)
+ assert.Len(t, results, 3)
+ assert.Equal(t, "c", results[2]["key"])
+ assert.Equal(t, "b", results[1]["key"])
+ assert.Equal(t, "a", results[0]["key"])
+
+ err = findOp.Filter(OrderedDocument{{"key", OrderedDocument{{"$gte", "b"}}}}).
+ Sort("key", true).All(&results) // test filter
+ assert.Nil(t, err)
+ assert.Len(t, results, 2)
+ assert.Equal(t, "b", results[0]["key"])
+ assert.Equal(t, "c", results[1]["key"])
+
+ wrapper.Destroy(t)
}
-func TestBasicOperations(t *testing.T) {
- t.Run("testInsert", testInsert)
- t.Run("testFindOne", testFindOne)
+func TestSimpleUpdateOneUpdateMany(t *testing.T) {
+ wrapper := NewTestStorageWrapper(t)
+ collectionName := "simple_update_one_update_many"
+ wrapper.AddCollection(collectionName)
+
+ insertOp := wrapper.Storage.Insert(collectionName).Context(wrapper.Context)
+ simpleDocs := []interface{}{
+ UnorderedDocument{"_id": "ida", "key": "a"},
+ UnorderedDocument{"key": "b"},
+ UnorderedDocument{"key": "c"},
+ }
+ _, err := insertOp.Many(simpleDocs)
+ assert.Nil(t, err)
+
+ updateOp := wrapper.Storage.Update(collectionName).Context(wrapper.Context)
+ isUpdated, err := updateOp.Filter(OrderedDocument{{"_id", "ida"}}).
+ One(OrderedDocument{{"key", "aa"}})
+ assert.Nil(t, err)
+ assert.True(t, isUpdated)
+
+ updated, err := updateOp.Filter(OrderedDocument{{"key", OrderedDocument{{"$gte", "b"}}}}).
+ Many(OrderedDocument{{"key", "bb"}})
+ assert.Nil(t, err)
+ assert.Equal(t, int64(2), updated)
+
+ var upsertId interface{}
+ isUpdated, err = updateOp.Upsert(&upsertId).Filter(OrderedDocument{{"key", "d"}}).
+ One(OrderedDocument{{"key", "d"}})
+ assert.Nil(t, err)
+ assert.False(t, isUpdated)
+ assert.NotNil(t, upsertId)
+
+ var results []UnorderedDocument
+ findOp := wrapper.Storage.Find(collectionName).Context(wrapper.Context)
+ err = findOp.Sort("key", true).All(&results) // test sort ascending
+ assert.Nil(t, err)
+ assert.Len(t, results, 4)
+ assert.Equal(t, "aa", results[0]["key"])
+ assert.Equal(t, "bb", results[1]["key"])
+ assert.Equal(t, "bb", results[2]["key"])
+ assert.Equal(t, "d", results[3]["key"])
+
+ wrapper.Destroy(t)
}
-func TestInsertManyFindDocuments(t *testing.T) {
+func TestComplexInsertManyFindMany(t *testing.T) {
+ wrapper := NewTestStorageWrapper(t)
+ collectionName := "complex_insert_many_find_many"
+ wrapper.AddCollection(collectionName)
+
testTime := time.Now()
oid1, err := primitive.ObjectIDFromHex("ffffffffffffffffffffffff")
assert.Nil(t, err)
@@ -117,7 +190,7 @@ func TestInsertManyFindDocuments(t *testing.T) {
B: 0,
C: testTime,
D: map[string]b{
- "first": {A: "0", B: 0},
+ "first": {A: "0", B: 0},
"second": {A: "1", B: 1},
},
E: []b{
@@ -126,22 +199,22 @@ func TestInsertManyFindDocuments(t *testing.T) {
},
a{
Id: oid1,
- A: "test1",
- B: 1,
- C: testTime,
- D: map[string]b{},
- E: []b{},
+ A: "test1",
+ B: 1,
+ C: testTime,
+ D: map[string]b{},
+ E: []b{},
},
a{},
}
- ids, err := storage.InsertMany(testContext, testInsertManyFindCollection, docs)
+ ids, err := wrapper.Storage.Insert(collectionName).Context(wrapper.Context).Many(docs)
assert.Nil(t, err)
assert.Len(t, ids, 3)
assert.Equal(t, ids[1], oid1)
var results []a
- err = storage.Find(testContext, testInsertManyFindCollection, NoFilters, &results)
+ err = wrapper.Storage.Find(collectionName).Context(wrapper.Context).All(&results)
assert.Nil(t, err)
assert.Len(t, results, 3)
doc0, doc1, doc2 := docs[0].(a), docs[1].(a), docs[2].(a)
@@ -163,57 +236,6 @@ func TestInsertManyFindDocuments(t *testing.T) {
assert.Equal(t, doc0.E, results[0].E)
assert.Nil(t, results[1].E)
assert.Nil(t, results[2].E)
-}
-type testStorage struct {
- insertFunc func(ctx context.Context, collectionName string, document interface{}) (interface{}, error)
- insertManyFunc func(ctx context.Context, collectionName string, document []interface{}) ([]interface{}, error)
- updateOne func(ctx context.Context, collectionName string, filter interface{}, update interface {}, upsert bool) (interface{}, error)
- updateMany func(ctx context.Context, collectionName string, filter interface{}, update interface {}, upsert bool) (interface{}, error)
- findOne func(ctx context.Context, collectionName string, filter interface{}) (UnorderedDocument, error)
- find func(ctx context.Context, collectionName string, filter interface{}, results interface{}) error
-}
-
-func (ts testStorage) InsertOne(ctx context.Context, collectionName string, document interface{}) (interface{}, error) {
- if ts.insertFunc != nil {
- return ts.insertFunc(ctx, collectionName, document)
- }
- return nil, errors.New("not implemented")
-}
-
-func (ts testStorage) InsertMany(ctx context.Context, collectionName string, document []interface{}) ([]interface{}, error) {
- if ts.insertFunc != nil {
- return ts.insertManyFunc(ctx, collectionName, document)
- }
- return nil, errors.New("not implemented")
-}
-
-func (ts testStorage) UpdateOne(ctx context.Context, collectionName string, filter interface{}, update interface {},
- upsert bool) (interface{}, error) {
- if ts.updateOne != nil {
- return ts.updateOne(ctx, collectionName, filter, update, upsert)
- }
- return nil, errors.New("not implemented")
-}
-
-func (ts testStorage) UpdateMany(ctx context.Context, collectionName string, filter interface{}, update interface {},
- upsert bool) (interface{}, error) {
- if ts.updateOne != nil {
- return ts.updateMany(ctx, collectionName, filter, update, upsert)
- }
- return nil, errors.New("not implemented")
-}
-
-func (ts testStorage) FindOne(ctx context.Context, collectionName string, filter interface{}) (UnorderedDocument, error) {
- if ts.findOne != nil {
- return ts.findOne(ctx, collectionName, filter)
- }
- return nil, errors.New("not implemented")
-}
-
-func (ts testStorage) Find(ctx context.Context, collectionName string, filter interface{}, results interface{}) error {
- if ts.find != nil {
- return ts.find(ctx, collectionName, filter, results)
- }
- return errors.New("not implemented")
+ wrapper.Destroy(t)
}
diff --git a/stream_handler.go b/stream_handler.go
index ce580fc..3fafa21 100644
--- a/stream_handler.go
+++ b/stream_handler.go
@@ -2,11 +2,9 @@ package main
import (
"bytes"
- "encoding/binary"
- "fmt"
"github.com/flier/gohs/hyperscan"
"github.com/google/gopacket/tcpassembly"
- "log"
+ log "github.com/sirupsen/logrus"
"time"
)
@@ -28,30 +26,28 @@ type StreamHandler struct {
currentIndex int
firstPacketSeen time.Time
lastPacketSeen time.Time
- documentsKeys []string
+ documentsKeys []RowID
streamLength int
patternStream hyperscan.Stream
patternMatches map[uint][]PatternSlice
}
-type PatternSlice [2]uint64
-
// NewReaderStream returns a new StreamHandler object.
func NewStreamHandler(connection ConnectionHandler, key StreamKey, scratch *hyperscan.Scratch) StreamHandler {
handler := StreamHandler{
- connection: connection,
- streamKey: key,
- buffer: new(bytes.Buffer),
- indexes: make([]int, 0, InitialBlockCount),
- timestamps: make([]time.Time, 0, InitialBlockCount),
- lossBlocks: make([]bool, 0, InitialBlockCount),
- documentsKeys: make([]string, 0, 1), // most of the time the stream fit in one document
+ connection: connection,
+ streamKey: key,
+ buffer: new(bytes.Buffer),
+ indexes: make([]int, 0, InitialBlockCount),
+ timestamps: make([]time.Time, 0, InitialBlockCount),
+ lossBlocks: make([]bool, 0, InitialBlockCount),
+ documentsKeys: make([]RowID, 0, 1), // most of the time the stream fits in one document
patternMatches: make(map[uint][]PatternSlice, 10), // TODO: replace with the exact value
}
stream, err := connection.Patterns().Open(0, scratch, handler.onMatch, nil)
if err != nil {
- log.Println("failed to create a stream: ", err)
+ log.WithField("streamKey", key).WithError(err).Error("failed to create a stream")
}
handler.patternStream = stream
@@ -81,7 +77,7 @@ func (sh *StreamHandler) Reassembled(reassembly []tcpassembly.Reassembly) {
}
n, err := sh.buffer.Write(r.Bytes[skip:])
if err != nil {
- log.Println("error while copying bytes from Reassemble in stream_handler")
+ log.WithError(err).Error("failed to copy bytes from a Reassemble")
return
}
sh.indexes = append(sh.indexes, sh.currentIndex)
@@ -92,7 +88,7 @@ func (sh *StreamHandler) Reassembled(reassembly []tcpassembly.Reassembly) {
err = sh.patternStream.Scan(r.Bytes)
if err != nil {
- log.Println("failed to scan packet buffer: ", err)
+ log.WithError(err).Error("failed to scan packet buffer")
}
}
}
@@ -101,7 +97,7 @@ func (sh *StreamHandler) Reassembled(reassembly []tcpassembly.Reassembly) {
func (sh *StreamHandler) ReassemblyComplete() {
err := sh.patternStream.Close()
if err != nil {
- log.Println("failed to close pattern stream: ", err)
+ log.WithError(err).Error("failed to close pattern stream")
}
if sh.currentIndex > 0 {
@@ -144,33 +140,26 @@ func (sh *StreamHandler) onMatch(id uint, from uint64, to uint64, flags uint, co
}
func (sh *StreamHandler) storageCurrentDocument() {
- streamKey := sh.generateDocumentKey()
-
- _, err := sh.connection.Storage().InsertOne(sh.connection.Context(), "connection_streams", OrderedDocument{
- {"_id", streamKey},
- {"connection_id", nil},
- {"document_index", len(sh.documentsKeys)},
- {"payload", sh.buffer.Bytes()},
- {"blocks_indexes", sh.indexes},
- {"blocks_timestamps", sh.timestamps},
- {"blocks_loss", sh.lossBlocks},
- {"pattern_matches", sh.patternMatches},
- })
+ payload := (sh.streamKey[0].FastHash()^sh.streamKey[1].FastHash()^sh.streamKey[2].FastHash()^
+ sh.streamKey[3].FastHash())&uint64(0xffffffffffffff00) | uint64(len(sh.documentsKeys)) // top 7 bytes: endpoints hash, low byte: document index
+ streamKey := sh.connection.Storage().NewCustomRowID(payload, sh.firstPacketSeen)
+
+ _, err := sh.connection.Storage().Insert(ConnectionStreams).
+ Context(sh.connection.Context()).
+ One(ConnectionStream{
+ ID: streamKey,
+ ConnectionID: ZeroRowID,
+ DocumentIndex: len(sh.documentsKeys),
+ Payload: sh.buffer.Bytes(),
+ BlocksIndexes: sh.indexes,
+ BlocksTimestamps: sh.timestamps,
+ BlocksLoss: sh.lossBlocks,
+ PatternMatches: sh.patternMatches,
+ })
if err != nil {
- log.Println("failed to insert connection stream: ", err)
+ log.WithError(err).Error("failed to insert connection stream")
}
sh.documentsKeys = append(sh.documentsKeys, streamKey)
}
-
-func (sh *StreamHandler) generateDocumentKey() string {
- hash := make([]byte, 16)
- endpointsHash := sh.streamKey[0].FastHash() ^ sh.streamKey[1].FastHash() ^
- sh.streamKey[2].FastHash() ^ sh.streamKey[3].FastHash()
- binary.BigEndian.PutUint64(hash, endpointsHash)
- binary.BigEndian.PutUint64(hash[8:], uint64(sh.firstPacketSeen.UnixNano()))
- binary.BigEndian.PutUint16(hash[8:], uint16(len(sh.documentsKeys)))
-
- return fmt.Sprintf("%x", hash)
-}
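
The payload expression in storageCurrentDocument packs two things into the custom RowID; a sketch spelling out the layout (an assumption read off the code above, not documented elsewhere):

    // documentKey sketches how a stream's RowID is derived: the top 7 bytes of
    // the XORed endpoint hashes, the low byte for the index of this document
    // within the stream, and the 4-byte timestamp that NewCustomRowID prepends,
    // so ids belonging to the same stream sort together.
    func documentKey(storage Storage, endpointsHash uint64, documentIndex int, firstPacketSeen time.Time) RowID {
        payload := endpointsHash&uint64(0xffffffffffffff00) | uint64(documentIndex)
        return storage.NewCustomRowID(payload, firstPacketSeen)
    }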
diff --git a/stream_handler_test.go b/stream_handler_test.go
index cb5ecc7..ece3190 100644
--- a/stream_handler_test.go
+++ b/stream_handler_test.go
@@ -2,7 +2,6 @@ package main
import (
"context"
- "fmt"
"github.com/flier/gohs/hyperscan"
"github.com/google/gopacket/layers"
"github.com/google/gopacket/tcpassembly"
@@ -19,13 +18,14 @@ const testDstIp = "10.10.10.1"
const srcPort = 44444
const dstPort = 8080
-
func TestReassemblingEmptyStream(t *testing.T) {
+ wrapper := NewTestStorageWrapper(t)
+ wrapper.AddCollection(ConnectionStreams)
patterns, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/nope/", 0))
- require.Nil(t, err)
+ require.NoError(t, err)
scratch, err := hyperscan.NewScratch(patterns)
- require.Nil(t, err)
- streamHandler := createTestStreamHandler(testStorage{}, patterns, scratch)
+ require.NoError(t, err)
+ streamHandler := createTestStreamHandler(wrapper, patterns, scratch)
streamHandler.Reassembled([]tcpassembly.Reassembly{{
Bytes: []byte{},
@@ -51,19 +51,20 @@ func TestReassemblingEmptyStream(t *testing.T) {
assert.Equal(t, true, completed)
err = scratch.Free()
- require.Nil(t, err, "free scratch")
+ require.NoError(t, err, "free scratch")
err = patterns.Close()
- require.Nil(t, err, "close stream database")
+ require.NoError(t, err, "close stream database")
+ wrapper.Destroy(t)
}
-
func TestReassemblingSingleDocument(t *testing.T) {
+ wrapper := NewTestStorageWrapper(t)
+ wrapper.AddCollection(ConnectionStreams)
patterns, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/impossible_to_match/", 0))
- require.Nil(t, err)
+ require.NoError(t, err)
scratch, err := hyperscan.NewScratch(patterns)
- require.Nil(t, err)
- storage := &testStorage{}
- streamHandler := createTestStreamHandler(storage, patterns, scratch)
+ require.NoError(t, err)
+ streamHandler := createTestStreamHandler(wrapper, patterns, scratch)
payloadLen := 256
firstTime := time.Unix(0, 0)
@@ -71,10 +72,10 @@ func TestReassemblingSingleDocument(t *testing.T) {
lastTime := time.Unix(20, 0)
data := make([]byte, MaxDocumentSize)
rand.Read(data)
- reassembles := make([]tcpassembly.Reassembly, MaxDocumentSize / payloadLen)
- indexes := make([]int, MaxDocumentSize / payloadLen)
- timestamps := make([]time.Time, MaxDocumentSize / payloadLen)
- lossBlocks := make([]bool, MaxDocumentSize / payloadLen)
+ reassembles := make([]tcpassembly.Reassembly, MaxDocumentSize/payloadLen)
+ indexes := make([]int, MaxDocumentSize/payloadLen)
+ timestamps := make([]time.Time, MaxDocumentSize/payloadLen)
+ lossBlocks := make([]bool, MaxDocumentSize/payloadLen)
for i := 0; i < len(reassembles); i++ {
var seen time.Time
if i == 0 {
@@ -86,36 +87,22 @@ func TestReassemblingSingleDocument(t *testing.T) {
}
reassembles[i] = tcpassembly.Reassembly{
- Bytes: data[i*payloadLen:(i+1)*payloadLen],
+ Bytes: data[i*payloadLen : (i+1)*payloadLen],
Skip: 0,
Start: i == 0,
End: i == len(reassembles)-1,
Seen: seen,
}
- indexes[i] = i*payloadLen
+ indexes[i] = i * payloadLen
timestamps[i] = seen
}
- inserted := false
- storage.insertFunc = func(ctx context.Context, collectionName string, document interface{}) (i interface{}, err error) {
- od := document.(OrderedDocument)
- assert.Equal(t, "connection_streams", collectionName)
- assert.Equal(t, "bb41a60281cfae830000000000000000", od[0].Value)
- assert.Equal(t, nil, od[1].Value)
- assert.Equal(t, 0, od[2].Value)
- assert.Equal(t, data, od[3].Value)
- assert.Equal(t, indexes, od[4].Value)
- assert.Equal(t, timestamps, od[5].Value)
- assert.Equal(t, lossBlocks, od[6].Value)
- assert.Len(t, od[7].Value, 0)
- inserted = true
- return nil, nil
- }
+ var results []ConnectionStream
streamHandler.Reassembled(reassembles)
- if !assert.Equal(t, false, inserted) {
- inserted = false
- }
+ err = wrapper.Storage.Find(ConnectionStreams).Context(wrapper.Context).All(&results)
+ require.NoError(t, err)
+ assert.Len(t, results, 0)
completed := false
streamHandler.connection.(*testConnectionHandler).onComplete = func(handler *StreamHandler) {
@@ -123,6 +110,18 @@ func TestReassemblingSingleDocument(t *testing.T) {
}
streamHandler.ReassemblyComplete()
+ err = wrapper.Storage.Find(ConnectionStreams).Context(wrapper.Context).All(&results)
+ require.NoError(t, err)
+ assert.Len(t, results, 1)
+ assert.Equal(t, firstTime.Unix(), results[0].ID.Timestamp().Unix())
+ assert.Zero(t, results[0].ConnectionID)
+ assert.Equal(t, 0, results[0].DocumentIndex)
+ assert.Equal(t, data, results[0].Payload)
+ assert.Equal(t, indexes, results[0].BlocksIndexes)
+ assert.Len(t, results[0].BlocksTimestamps, len(timestamps)) // should be compared one by one
+ assert.Equal(t, lossBlocks, results[0].BlocksLoss)
+ assert.Len(t, results[0].PatternMatches, 0)
+
assert.Equal(t, len(data), streamHandler.currentIndex)
assert.Equal(t, firstTime, streamHandler.firstPacketSeen)
assert.Equal(t, lastTime, streamHandler.lastPacketSeen)
@@ -130,35 +129,35 @@ func TestReassemblingSingleDocument(t *testing.T) {
assert.Equal(t, len(data), streamHandler.streamLength)
assert.Len(t, streamHandler.patternMatches, 0)
- assert.Equal(t, true, inserted, "inserted")
assert.Equal(t, true, completed, "completed")
err = scratch.Free()
- require.Nil(t, err, "free scratch")
+ require.NoError(t, err, "free scratch")
err = patterns.Close()
- require.Nil(t, err, "close stream database")
+ require.NoError(t, err, "close stream database")
+ wrapper.Destroy(t)
}
-
func TestReassemblingMultipleDocuments(t *testing.T) {
+ wrapper := NewTestStorageWrapper(t)
+ wrapper.AddCollection(ConnectionStreams)
patterns, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/impossible_to_match/", 0))
- require.Nil(t, err)
+ require.NoError(t, err)
scratch, err := hyperscan.NewScratch(patterns)
- require.Nil(t, err)
- storage := &testStorage{}
- streamHandler := createTestStreamHandler(storage, patterns, scratch)
+ require.NoError(t, err)
+ streamHandler := createTestStreamHandler(wrapper, patterns, scratch)
payloadLen := 256
firstTime := time.Unix(0, 0)
middleTime := time.Unix(10, 0)
lastTime := time.Unix(20, 0)
- dataSize := MaxDocumentSize*2
+ dataSize := MaxDocumentSize * 2
data := make([]byte, dataSize)
rand.Read(data)
- reassembles := make([]tcpassembly.Reassembly, dataSize / payloadLen)
- indexes := make([]int, dataSize / payloadLen)
- timestamps := make([]time.Time, dataSize / payloadLen)
- lossBlocks := make([]bool, dataSize / payloadLen)
+ reassembles := make([]tcpassembly.Reassembly, dataSize/payloadLen)
+ indexes := make([]int, dataSize/payloadLen)
+ timestamps := make([]time.Time, dataSize/payloadLen)
+ lossBlocks := make([]bool, dataSize/payloadLen)
for i := 0; i < len(reassembles); i++ {
var seen time.Time
if i == 0 {
@@ -170,38 +169,22 @@ func TestReassemblingMultipleDocuments(t *testing.T) {
}
reassembles[i] = tcpassembly.Reassembly{
- Bytes: data[i*payloadLen:(i+1)*payloadLen],
+ Bytes: data[i*payloadLen : (i+1)*payloadLen],
Skip: 0,
Start: i == 0,
End: i == len(reassembles)-1,
Seen: seen,
}
- indexes[i] = i*payloadLen % MaxDocumentSize
+ indexes[i] = i * payloadLen % MaxDocumentSize
timestamps[i] = seen
}
- inserted := 0
- storage.insertFunc = func(ctx context.Context, collectionName string, document interface{}) (i interface{}, err error) {
- od := document.(OrderedDocument)
- blockLen := MaxDocumentSize / payloadLen
- assert.Equal(t, "connection_streams", collectionName)
- assert.Equal(t, fmt.Sprintf("bb41a60281cfae83000%v000000000000", inserted), od[0].Value)
- assert.Equal(t, nil, od[1].Value)
- assert.Equal(t, inserted, od[2].Value)
- assert.Equal(t, data[MaxDocumentSize*inserted:MaxDocumentSize*(inserted+1)], od[3].Value)
- assert.Equal(t, indexes[blockLen*inserted:blockLen*(inserted+1)], od[4].Value)
- assert.Equal(t, timestamps[blockLen*inserted:blockLen*(inserted+1)], od[5].Value)
- assert.Equal(t, lossBlocks[blockLen*inserted:blockLen*(inserted+1)], od[6].Value)
- assert.Len(t, od[7].Value, 0)
- inserted += 1
-
- return nil, nil
- }
-
streamHandler.Reassembled(reassembles)
- if !assert.Equal(t, 1, inserted) {
- inserted = 1
- }
+
+ var results []ConnectionStream
+ err = wrapper.Storage.Find(ConnectionStreams).Context(wrapper.Context).All(&results)
+ require.NoError(t, err)
+ assert.Len(t, results, 1)
completed := false
streamHandler.connection.(*testConnectionHandler).onComplete = func(handler *StreamHandler) {
@@ -209,6 +192,21 @@ func TestReassemblingMultipleDocuments(t *testing.T) {
}
streamHandler.ReassemblyComplete()
+ err = wrapper.Storage.Find(ConnectionStreams).Context(wrapper.Context).All(&results)
+ require.NoError(t, err)
+ assert.Len(t, results, 2)
+ for i := 0; i < 2; i++ {
+ blockLen := MaxDocumentSize / payloadLen
+ assert.Equal(t, firstTime.Unix(), results[i].ID.Timestamp().Unix())
+ assert.Zero(t, results[i].ConnectionID)
+ assert.Equal(t, i, results[i].DocumentIndex)
+ assert.Equal(t, data[MaxDocumentSize*i:MaxDocumentSize*(i+1)], results[i].Payload)
+ assert.Equal(t, indexes[blockLen*i:blockLen*(i+1)], results[i].BlocksIndexes)
+ assert.Len(t, results[i].BlocksTimestamps, len(timestamps[blockLen*i:blockLen*(i+1)])) // should be compared one by one
+ assert.Equal(t, lossBlocks[blockLen*i:blockLen*(i+1)], results[i].BlocksLoss)
+ assert.Len(t, results[i].PatternMatches, 0)
+ }
+
assert.Equal(t, MaxDocumentSize, streamHandler.currentIndex)
assert.Equal(t, firstTime, streamHandler.firstPacketSeen)
assert.Equal(t, lastTime, streamHandler.lastPacketSeen)
@@ -216,26 +214,28 @@ func TestReassemblingMultipleDocuments(t *testing.T) {
assert.Equal(t, len(data), streamHandler.streamLength)
assert.Len(t, streamHandler.patternMatches, 0)
- assert.Equal(t, 2, inserted, "inserted")
assert.Equal(t, true, completed, "completed")
err = scratch.Free()
- require.Nil(t, err, "free scratch")
+ require.NoError(t, err, "free scratch")
err = patterns.Close()
- require.Nil(t, err, "close stream database")
+ require.NoError(t, err, "close stream database")
+ wrapper.Destroy(t)
}
func TestReassemblingPatternMatching(t *testing.T) {
+ wrapper := NewTestStorageWrapper(t)
+ wrapper.AddCollection(ConnectionStreams)
a, err := hyperscan.ParsePattern("/a{8}/i")
- require.Nil(t, err)
+ require.NoError(t, err)
a.Id = 0
a.Flags |= hyperscan.SomLeftMost
b, err := hyperscan.ParsePattern("/b[c]+b/i")
- require.Nil(t, err)
+ require.NoError(t, err)
b.Id = 1
b.Flags |= hyperscan.SomLeftMost
d, err := hyperscan.ParsePattern("/[d]+e[d]+/i")
- require.Nil(t, err)
+ require.NoError(t, err)
d.Id = 2
d.Flags |= hyperscan.SomLeftMost
@@ -247,30 +247,12 @@ func TestReassemblingPatternMatching(t *testing.T) {
}
patterns, err := hyperscan.NewStreamDatabase(a, b, d)
- require.Nil(t, err)
+ require.NoError(t, err)
scratch, err := hyperscan.NewScratch(patterns)
- require.Nil(t, err)
- storage := &testStorage{}
- streamHandler := createTestStreamHandler(storage, patterns, scratch)
+ require.NoError(t, err)
+ streamHandler := createTestStreamHandler(wrapper, patterns, scratch)
seen := time.Unix(0, 0)
- inserted := false
- storage.insertFunc = func(ctx context.Context, collectionName string, document interface{}) (i interface{}, err error) {
- od := document.(OrderedDocument)
- assert.Equal(t, "connection_streams", collectionName)
- assert.Equal(t, "bb41a60281cfae830000000000000000", od[0].Value)
- assert.Equal(t, nil, od[1].Value)
- assert.Equal(t, 0, od[2].Value)
- assert.Equal(t, []byte(payload), od[3].Value)
- assert.Equal(t, []int{0}, od[4].Value)
- assert.Equal(t, []time.Time{seen}, od[5].Value)
- assert.Equal(t, []bool{false}, od[6].Value)
- assert.Equal(t, expected, od[7].Value)
- inserted = true
-
- return nil, nil
- }
-
streamHandler.Reassembled([]tcpassembly.Reassembly{{
Bytes: []byte(payload),
Skip: 0,
@@ -278,7 +260,11 @@ func TestReassemblingPatternMatching(t *testing.T) {
End: true,
Seen: seen,
}})
- assert.Equal(t, false, inserted)
+
+ var results []ConnectionStream
+ err = wrapper.Storage.Find(ConnectionStreams).Context(wrapper.Context).All(&results)
+ require.NoError(t, err)
+ assert.Len(t, results, 0)
completed := false
streamHandler.connection.(*testConnectionHandler).onComplete = func(handler *StreamHandler) {
@@ -286,26 +272,36 @@ func TestReassemblingPatternMatching(t *testing.T) {
}
streamHandler.ReassemblyComplete()
+ err = wrapper.Storage.Find(ConnectionStreams).Context(wrapper.Context).All(&results)
+ require.NoError(t, err)
+ assert.Len(t, results, 1)
+ assert.Equal(t, seen.Unix(), results[0].ID.Timestamp().Unix())
+ assert.Zero(t, results[0].ConnectionID)
+ assert.Equal(t, 0, results[0].DocumentIndex)
+ assert.Equal(t, []byte(payload), results[0].Payload)
+ assert.Equal(t, []int{0}, results[0].BlocksIndexes)
+ assert.Len(t, results[0].BlocksTimestamps, 1) // should be compared one by one
+ assert.Equal(t, []bool{false}, results[0].BlocksLoss)
+ assert.Equal(t, expected, results[0].PatternMatches)
+
assert.Equal(t, len(payload), streamHandler.currentIndex)
assert.Equal(t, seen, streamHandler.firstPacketSeen)
assert.Equal(t, seen, streamHandler.lastPacketSeen)
assert.Len(t, streamHandler.documentsKeys, 1)
assert.Equal(t, len(payload), streamHandler.streamLength)
- assert.Equal(t, true, inserted, "inserted")
assert.Equal(t, true, completed, "completed")
err = scratch.Free()
- require.Nil(t, err, "free scratch")
+ require.NoError(t, err, "free scratch")
err = patterns.Close()
- require.Nil(t, err, "close stream database")
+ require.NoError(t, err, "close stream database")
+ wrapper.Destroy(t)
}
-
-func createTestStreamHandler(storage Storage, patterns hyperscan.StreamDatabase, scratch *hyperscan.Scratch) StreamHandler {
+func createTestStreamHandler(wrapper *TestStorageWrapper, patterns hyperscan.StreamDatabase, scratch *hyperscan.Scratch) StreamHandler {
testConnectionHandler := &testConnectionHandler{
- storage: storage,
- context: context.Background(),
+ wrapper: wrapper,
patterns: patterns,
}
@@ -318,18 +314,17 @@ func createTestStreamHandler(storage Storage, patterns hyperscan.StreamDatabase,
}
type testConnectionHandler struct {
- storage Storage
- context context.Context
- patterns hyperscan.StreamDatabase
+ wrapper *TestStorageWrapper
+ patterns hyperscan.StreamDatabase
onComplete func(*StreamHandler)
}
func (tch *testConnectionHandler) Storage() Storage {
- return tch.storage
+ return tch.wrapper.Storage
}
func (tch *testConnectionHandler) Context() context.Context {
- return tch.context
+ return tch.wrapper.Context
}
func (tch *testConnectionHandler) Patterns() hyperscan.StreamDatabase {
 return tch.patterns
}