From 0520dab47d61e2c4de246459bf4f5c72d69182d3 Mon Sep 17 00:00:00 2001
From: Emiliano Ciavatta
Date: Thu, 9 Apr 2020 10:26:15 +0200
Subject: Refactor storage

---
 caronte_test.go        |  60 +++++----
 connection_handler.go  |  14 +-
 connection_streams.go  |  16 +++
 go.mod                 |   1 +
 go.sum                 |   1 +
 pcap_importer.go       |  22 ++--
 routes.go              |   3 +-
 rules_manager.go       | 132 +++++++------------
 storage.go             | 339 +++++++++++++++++++++++++++++++------------------
 storage_test.go        | 300 +++++++++++++++++++++++--------------------
 stream_handler.go      |  71 +++++------
 stream_handler_test.go | 219 ++++++++++++++++----------------
 12 files changed, 628 insertions(+), 550 deletions(-)
 create mode 100644 connection_streams.go

diff --git a/caronte_test.go b/caronte_test.go
index 9942086..2766640 100644
--- a/caronte_test.go
+++ b/caronte_test.go
@@ -4,21 +4,21 @@ import (
 	"context"
 	"crypto/sha256"
 	"fmt"
-	"go.mongodb.org/mongo-driver/mongo"
-	"go.mongodb.org/mongo-driver/mongo/options"
-	"log"
+	log "github.com/sirupsen/logrus"
+	"github.com/stretchr/testify/require"
 	"os"
+	"strconv"
 	"testing"
 	"time"
 )
 
-var storage Storage
-var testContext context.Context
-
-const testInsertManyFindCollection = "testFi"
-const testCollection = "characters"
+type TestStorageWrapper struct {
+	DbName  string
+	Storage *MongoStorage
+	Context context.Context
+}
 
-func TestMain(m *testing.M) {
+func NewTestStorageWrapper(t *testing.T) *TestStorageWrapper {
 	mongoHost, ok := os.LookupEnv("MONGO_HOST")
 	if !ok {
 		mongoHost = "localhost"
@@ -27,33 +27,31 @@ func TestMain(m *testing.M) {
 	if !ok {
 		mongoPort = "27017"
 	}
+	port, err := strconv.Atoi(mongoPort)
+	require.NoError(t, err, "invalid port")
 
 	uniqueDatabaseName := sha256.Sum256([]byte(time.Now().String()))
-
-	client, err := mongo.NewClient(options.Client().ApplyURI(fmt.Sprintf("mongodb://%s:%v", mongoHost, mongoPort)))
-	if err != nil {
-		panic("failed to create mongo client")
-	}
-
 	dbName := fmt.Sprintf("%x", uniqueDatabaseName[:31])
-	db := client.Database(dbName)
-	log.Println("using database", dbName)
-	mongoStorage := MongoStorage{
-		client: client,
-		collections: map[string]*mongo.Collection{
-			testInsertManyFindCollection: db.Collection(testInsertManyFindCollection),
-			testCollection:               db.Collection(testCollection),
-		},
-	}
+	log.WithField("dbName", dbName).Info("creating new storage")
+
+	storage := NewMongoStorage(mongoHost, port, dbName)
+	ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
 
-	testContext, _ = context.WithTimeout(context.Background(), 10 * time.Second)
+	err = storage.Connect(ctx)
+	require.NoError(t, err, "failed to connect to database")
 
-	err = mongoStorage.Connect(testContext)
-	if err != nil {
-		panic(err)
+	return &TestStorageWrapper{
+		DbName:  dbName,
+		Storage: storage,
+		Context: ctx,
 	}
-	storage = &mongoStorage
+}
+
+func (tsw TestStorageWrapper) AddCollection(collectionName string) {
+	tsw.Storage.collections[collectionName] = tsw.Storage.client.Database(tsw.DbName).Collection(collectionName)
+}
 
-	exitCode := m.Run()
-	os.Exit(exitCode)
+func (tsw TestStorageWrapper) Destroy(t *testing.T) {
+	err := tsw.Storage.client.Disconnect(tsw.Context)
+	require.NoError(t, err, "failed to disconnect from database")
 }
diff --git a/connection_handler.go b/connection_handler.go
index 36dca6d..2e2fa84 100644
--- a/connection_handler.go
+++ b/connection_handler.go
@@ -64,7 +64,7 @@ func (factory *BiDirectionalStreamFactory) New(net, transport gopacket.Flow) tcp
 			initiator: initiator,
 			mComplete: sync.Mutex{},
 			context:   context.Background(),
-			patterns : factory.patterns,
+
patterns: factory.patterns, } factory.connections[key] = connection } @@ -122,7 +122,7 @@ func (ch *connectionHandlerImpl) Complete(handler *StreamHandler) { ch.generateConnectionKey(startedAt) - _, err := ch.storage.InsertOne(ch.context, "connections", OrderedDocument{ + _, err := ch.storage.Insert("connections").Context(ch.context).One(OrderedDocument{ {"_id", ch.connectionKey}, {"ip_src", ch.initiator[0].String()}, {"ip_dst", ch.initiator[1].String()}, @@ -141,14 +141,14 @@ func (ch *connectionHandlerImpl) Complete(handler *StreamHandler) { } streamsIds := append(client.documentsKeys, server.documentsKeys...) - n, err := ch.storage.UpdateOne(ch.context, "connection_streams", - UnorderedDocument{"_id": UnorderedDocument{"$in": streamsIds}}, - UnorderedDocument{"connection_id": ch.connectionKey}, - false) + n, err := ch.storage.Update("connection_streams"). + Context(ch.context). + Filter(OrderedDocument{{"_id", UnorderedDocument{"$in": streamsIds}}}). + Many(UnorderedDocument{"connection_id": ch.connectionKey}) if err != nil { log.Println("failed to update connection streams", err) } - if n != len(streamsIds) { + if int(n) != len(streamsIds) { log.Println("failed to update all connections streams") } } diff --git a/connection_streams.go b/connection_streams.go new file mode 100644 index 0000000..bede526 --- /dev/null +++ b/connection_streams.go @@ -0,0 +1,16 @@ +package main + +import "time" + +type ConnectionStream struct { + ID RowID `json:"id" bson:"_id"` + ConnectionID RowID `json:"connection_id" bson:"connection_id"` + DocumentIndex int `json:"document_index" bson:"document_index"` + Payload []byte `json:"payload" bson:"payload"` + BlocksIndexes []int `json:"blocks_indexes" bson:"blocks_indexes"` + BlocksTimestamps []time.Time `json:"blocks_timestamps" bson:"blocks_timestamps"` + BlocksLoss []bool `json:"blocks_loss" bson:"blocks_loss"` + PatternMatches map[uint][]PatternSlice `json:"pattern_matches" bson:"pattern_matches"` +} + +type PatternSlice [2]uint64 diff --git a/go.mod b/go.mod index 99863d7..ce12c15 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/google/gopacket v1.1.17 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.1 // indirect + github.com/sirupsen/logrus v1.4.2 github.com/stretchr/testify v1.4.0 go.mongodb.org/mongo-driver v1.3.1 golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 diff --git a/go.sum b/go.sum index 16a4ce9..d2863ce 100644 --- a/go.sum +++ b/go.sum @@ -103,6 +103,7 @@ github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/sirupsen/logrus v1.4.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/smartystreets/assertions v0.0.0-20180820201707-7c9eb446e3cf h1:6V1qxN6Usn4jy8unvggSJz/NC790tefw8Zdy6OZS5co= github.com/smartystreets/assertions v0.0.0-20180820201707-7c9eb446e3cf/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= diff --git a/pcap_importer.go b/pcap_importer.go index c6260de..ac7c0b5 100644 --- a/pcap_importer.go +++ b/pcap_importer.go @@ -22,7 +22,6 @@ const importUpdateProgressInterval = 3 * time.Second const initialPacketPerServicesMapSize = 16 const 
importedPcapsCollectionName = "imported_pcaps"
-
 type PcapImporter struct {
 	storage     Storage
 	streamPool  *tcpassembly.StreamPool
@@ -35,11 +34,10 @@ type PcapImporter struct {
 
 type flowCount [2]int
 
-
 func NewPcapImporter(storage Storage, serverIp net.IP) *PcapImporter {
 	serverEndpoint := layers.NewIPEndpoint(serverIp)
 	streamFactory := &BiDirectionalStreamFactory{
-		storage: storage,
+		storage:  storage,
 		serverIp: serverEndpoint,
 	}
 	streamPool := tcpassembly.NewStreamPool(streamFactory)
@@ -82,7 +80,7 @@ func (pi *PcapImporter) ImportPcap(fileName string) (string, error) {
 		{"importing_error", err},
 	}
 	ctx, canc := context.WithCancel(context.Background())
-	_, err = pi.storage.InsertOne(ctx, importedPcapsCollectionName, doc)
+	_, err = pi.storage.Insert(importedPcapsCollectionName).Context(ctx).One(doc)
 	if err != nil {
 		pi.mSessions.Unlock()
 		_, alreadyProcessed := err.(mongo.WriteException)
@@ -133,18 +131,18 @@ func (pi *PcapImporter) parsePcap(sessionId, fileName string, ctx context.Contex
 
 	progressUpdate := func(completed bool, err error) {
 		update := UnorderedDocument{
-			"processed_packets": processedPackets,
-			"invalid_packets": invalidPackets,
+			"processed_packets":    processedPackets,
+			"invalid_packets":      invalidPackets,
 			"packets_per_services": packetsPerService,
-			"importing_error": err,
+			"importing_error":      err,
 		}
 		if completed {
 			update["completed_at"] = time.Now()
 		}
 
-		_, _err := pi.storage.UpdateOne(nil, importedPcapsCollectionName, OrderedDocument{{"_id", sessionId}},
-			completed, false)
-
+		_, _err := pi.storage.Update(importedPcapsCollectionName).
+			Filter(OrderedDocument{{"_id", sessionId}}).
+			One(update)
 		if _err != nil {
 			log.Println("can't update importing statistics : ", _err)
 		}
@@ -158,10 +156,10 @@ func (pi *PcapImporter) parsePcap(sessionId, fileName string, ctx context.Contex
 
 	for {
 		select {
-		case <- ctx.Done():
+		case <-ctx.Done():
 			handle.Close()
 			deleteSession()
-			progressUpdate(false, errors.New("import process cancelled"))
+			progressUpdate(false, errors.New("import process cancelled"))
 			return
 		default:
 		}
diff --git a/routes.go b/routes.go
index f44cff7..3759382 100644
--- a/routes.go
+++ b/routes.go
@@ -4,7 +4,7 @@ import (
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/go-playground/validator/v10"
-	"log"
+	log "github.com/sirupsen/logrus"
 	"net/http"
 )
 
@@ -23,6 +23,7 @@ func ApplicationRoutes(engine *gin.Engine) {
 					"error": fmt.Sprintf("field '%v' does not respect the %v(%v) rule",
 						fieldErr.Field(), fieldErr.Tag(), fieldErr.Param()),
 				})
+				log.WithError(err).WithField("rule", fieldErr.Tag()).Error("request validation failed")
 				return // exit on first error
 			}
 		}
diff --git a/rules_manager.go b/rules_manager.go
index d6e9aaa..e5a8d38 100644
--- a/rules_manager.go
+++ b/rules_manager.go
@@ -3,56 +3,53 @@ package main
 import (
 	"context"
 	"crypto/sha256"
-	"encoding/json"
 	"errors"
 	"fmt"
 	"github.com/flier/gohs/hyperscan"
-	"go.mongodb.org/mongo-driver/bson/primitive"
-	"log"
+	log "github.com/sirupsen/logrus"
 	"sync"
 	"time"
 )
 
-
 type RegexFlags struct {
-	Caseless bool `json:"caseless"` // Set case-insensitive matching.
-	DotAll bool `json:"dot_all"` // Matching a `.` will not exclude newlines.
-	MultiLine bool `json:"multi_line"` // Set multi-line anchoring.
-	SingleMatch bool `json:"single_match"` // Set single-match only mode.
-	Utf8Mode bool `json:"utf_8_mode"` // Enable UTF-8 mode for this expression.
-	UnicodeProperty bool `json:"unicode_property"` // Enable Unicode property support for this expression
+	Caseless        bool `json:"caseless"`         // Set case-insensitive matching.
+	DotAll          bool `json:"dot_all"`          // Matching a `.` will not exclude newlines.
+	MultiLine       bool `json:"multi_line"`       // Set multi-line anchoring.
+	SingleMatch     bool `json:"single_match"`     // Set single-match only mode.
+	Utf8Mode        bool `json:"utf_8_mode"`       // Enable UTF-8 mode for this expression.
+	UnicodeProperty bool `json:"unicode_property"` // Enable Unicode property support for this expression.
 }
 
 type Pattern struct {
-	Regex string `json:"regex"`
-	Flags RegexFlags `json:"flags"`
-	MinOccurrences int `json:"min_occurrences"`
-	MaxOccurrences int `json:"max_occurrences"`
-	internalId int
+	Regex           string     `json:"regex"`
+	Flags           RegexFlags `json:"flags"`
+	MinOccurrences  int        `json:"min_occurrences"`
+	MaxOccurrences  int        `json:"max_occurrences"`
+	internalId      int
 	compiledPattern *hyperscan.Pattern
 }
 
 type Filter struct {
-	ServicePort int
+	ServicePort   int
 	ClientAddress string
-	ClientPort int
-	MinDuration int
-	MaxDuration int
-	MinPackets int
-	MaxPackets int
-	MinSize int
-	MaxSize int
+	ClientPort    int
+	MinDuration   int
+	MaxDuration   int
+	MinPackets    int
+	MaxPackets    int
+	MinSize       int
+	MaxSize       int
 }
 
 type Rule struct {
-	Id string `json:"-" bson:"_id,omitempty"`
-	Name string `json:"name" binding:"required,min=3" bson:"name"`
-	Color string `json:"color" binding:"required,hexcolor" bson:"color"`
-	Notes string `json:"notes" bson:"notes,omitempty"`
-	Enabled bool `json:"enabled" bson:"enabled"`
+	Id       RowID     `json:"-" bson:"_id,omitempty"`
+	Name     string    `json:"name" binding:"required,min=3" bson:"name"`
+	Color    string    `json:"color" binding:"required,hexcolor" bson:"color"`
+	Notes    string    `json:"notes" bson:"notes,omitempty"`
+	Enabled  bool      `json:"enabled" bson:"enabled"`
 	Patterns []Pattern `json:"patterns" binding:"required,min=1" bson:"patterns"`
-	Filter Filter `json:"filter" bson:"filter,omitempty"`
-	Version int64 `json:"version" bson:"version"`
+	Filter   Filter    `json:"filter" bson:"filter,omitempty"`
+	Version  int64     `json:"version" bson:"version"`
 }
 
 type RulesManager struct {
@@ -67,57 +64,53 @@ type RulesManager struct {
 
 func NewRulesManager(storage Storage) RulesManager {
 	return RulesManager{
-		storage: storage,
-		rules: make(map[string]Rule),
-		patterns: make(map[string]Pattern),
-		mPatterns: sync.Mutex{},
+		storage:     storage,
+		rules:       make(map[string]Rule),
+		rulesByName: make(map[string]Rule), // must be initialized: validateAndAddRuleLocal writes to it
+		patterns:    make(map[string]Pattern),
+		mPatterns:   sync.Mutex{},
 	}
 }
 
-
-
 func (rm RulesManager) LoadRules() error {
 	var rules []Rule
-	if err := rm.storage.Find(nil, Rules, NoFilters, &rules); err != nil {
+	if err := rm.storage.Find(Rules).Sort("_id", true).All(&rules); err != nil {
 		return err
 	}
 
-	var version int64
 	for _, rule := range rules {
 		if err := rm.validateAndAddRuleLocal(&rule); err != nil {
-			log.Printf("failed to import rule %s: %s\n", rule.Name, err)
-			continue
-		}
-		if rule.Version > version {
-			version = rule.Version
+			log.WithError(err).WithField("rule", rule).Warn("failed to import rule")
 		}
 	}
 
 	rm.ruleIndex = len(rules)
-	return rm.generateDatabase(0)
+	if len(rules) == 0 { // avoid indexing an empty slice below
+		return rm.generateDatabase(ZeroRowID)
+	}
+	return rm.generateDatabase(rules[len(rules)-1].Id)
 }
 
 func (rm RulesManager) AddRule(context context.Context, rule Rule) (string, error) {
 	rm.mPatterns.Lock()
-	rule.Id = UniqueKey(time.Now(), uint32(rm.ruleIndex))
+	rule.Id = rm.storage.NewCustomRowID(uint64(rm.ruleIndex), time.Now())
 	rule.Enabled = true
 
 	if err := rm.validateAndAddRuleLocal(&rule); err != nil {
 		rm.mPatterns.Unlock()
 		return "", err
 	}
+
+	if err := rm.generateDatabase(rule.Id); err != nil {
+		rm.mPatterns.Unlock()
+		log.WithError(err).WithField("rule", rule).Panic("failed to generate database")
+	}
rm.mPatterns.Unlock() - if _, err := rm.storage.InsertOne(context, Rules, rule); err != nil { - return "", err + if _, err := rm.storage.Insert(Rules).Context(context).One(rule); err != nil { + log.WithError(err).WithField("rule", rule).Panic("failed to insert rule on database") } - return rule.Id, rm.generateDatabase(rule.Id) + return rule.Id.Hex(), nil } - - func (rm RulesManager) validateAndAddRuleLocal(rule *Rule) error { if _, alreadyPresent := rm.rulesByName[rule.Name]; alreadyPresent { return errors.New("rule name must be unique") @@ -142,16 +135,18 @@ func (rm RulesManager) validateAndAddRuleLocal(rule *Rule) error { rm.patterns[key] = value } - rm.rules[rule.Id] = *rule + rm.rules[rule.Id.Hex()] = *rule rm.rulesByName[rule.Name] = *rule return nil } -func (rm RulesManager) generateDatabase(version string) error { +func (rm RulesManager) generateDatabase(version RowID) error { patterns := make([]*hyperscan.Pattern, len(rm.patterns)) + var i int for _, pattern := range rm.patterns { - patterns = append(patterns, pattern.compiledPattern) + patterns[i] = pattern.compiledPattern + i++ } database, err := hyperscan.NewStreamDatabase(patterns...) if err != nil { @@ -162,7 +157,6 @@ func (rm RulesManager) generateDatabase(version string) error { return nil } - func (p Pattern) BuildPattern() error { if p.compiledPattern != nil { return nil @@ -210,33 +204,3 @@ func (p Pattern) Hash() string { hash.Write([]byte(fmt.Sprintf("%s|%v|%v|%v", p.Regex, p.Flags, p.MinOccurrences, p.MaxOccurrences))) return fmt.Sprintf("%x", hash.Sum(nil)) } - -func test() { - user := &Pattern{Regex: "Frank"} - b, err := json.Marshal(user) - if err != nil { - fmt.Println(err) - return - } - fmt.Println(string(b)) - - p, _ := hyperscan.ParsePattern("/a/") - p1, _ := hyperscan.ParsePattern("/a/") - fmt.Println(p1.String(), p1.Flags) - //p1.Id = 1 - - fmt.Println(*p == *p1) - db, _ := hyperscan.NewBlockDatabase(p, p1) - s, _ := hyperscan.NewScratch(db) - db.Scan([]byte("Ciao"), s, onMatch, nil) - - - - -} - -func onMatch(id uint, from uint64, to uint64, flags uint, context interface{}) error { - fmt.Println(id) - - return nil -} \ No newline at end of file diff --git a/storage.go b/storage.go index 3be56d7..b88c5c8 100644 --- a/storage.go +++ b/storage.go @@ -6,60 +6,55 @@ import ( "encoding/hex" "errors" "fmt" - "time" - + log "github.com/sirupsen/logrus" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" + "time" ) +// Collections names const Connections = "connections" +const ConnectionStreams = "connection_streams" const ImportedPcaps = "imported_pcaps" const Rules = "rules" -var NoFilters = UnorderedDocument{} +const defaultConnectionTimeout = 10 * time.Second -const defaultConnectionTimeout = 10*time.Second -const defaultOperationTimeout = 3*time.Second +var ZeroRowID [12]byte type Storage interface { - InsertOne(ctx context.Context, collectionName string, document interface{}) (interface{}, error) - InsertMany(ctx context.Context, collectionName string, documents []interface{}) ([]interface{}, error) - UpdateOne(ctx context.Context, collectionName string, filter interface{}, update interface {}, upsert bool) (interface{}, error) - UpdateMany(ctx context.Context, collectionName string, filter interface{}, update interface {}, upsert bool) (interface{}, error) - FindOne(ctx context.Context, collectionName string, filter interface{}) (UnorderedDocument, error) - Find(ctx context.Context, collectionName 
string, filter interface{}, results interface{}) error
+	Insert(collectionName string) InsertOperation
+	Update(collectionName string) UpdateOperation
+	Find(collectionName string) FindOperation
+	NewCustomRowID(payload uint64, timestamp time.Time) RowID
+	NewRowID() RowID
 }
 
 type MongoStorage struct {
-	client *mongo.Client
+	client      *mongo.Client
 	collections map[string]*mongo.Collection
 }
 
 type OrderedDocument = bson.D
 type UnorderedDocument = bson.M
-
-func UniqueKey(timestamp time.Time, payload uint32) string {
-	var key [8]byte
-	binary.BigEndian.PutUint32(key[0:4], uint32(timestamp.Unix()))
-	binary.BigEndian.PutUint32(key[4:8], payload)
-
-	return hex.EncodeToString(key[:])
-}
+type Entry = bson.E
+type RowID = primitive.ObjectID
 
 func NewMongoStorage(uri string, port int, database string) *MongoStorage {
 	opt := options.Client()
 	opt.ApplyURI(fmt.Sprintf("mongodb://%s:%v", uri, port))
 	client, err := mongo.NewClient(opt)
 	if err != nil {
-		panic("Failed to create mongo client")
+		log.WithError(err).Panic("failed to create mongo client")
 	}
 
 	db := client.Database(database)
 	colls := map[string]*mongo.Collection{
-		Connections: db.Collection(Connections),
+		Connections:       db.Collection(Connections),
+		ConnectionStreams: db.Collection(ConnectionStreams), // register the new collection written by stream handlers
 		ImportedPcaps: db.Collection(ImportedPcaps),
-		Rules: db.Collection(Rules),
+		Rules:             db.Collection(Rules),
 	}
 
 	return &MongoStorage{
@@ -76,19 +71,55 @@ func (storage *MongoStorage) Connect(ctx context.Context) error {
 	return storage.client.Connect(ctx)
 }
 
-func (storage *MongoStorage) InsertOne(ctx context.Context, collectionName string,
-	document interface{}) (interface{}, error) {
+func (storage *MongoStorage) NewCustomRowID(payload uint64, timestamp time.Time) RowID {
+	var key [12]byte
+	binary.BigEndian.PutUint32(key[0:4], uint32(timestamp.Unix()))
+	binary.BigEndian.PutUint64(key[4:12], payload)
 
-	collection, ok := storage.collections[collectionName]
-	if !ok {
-		return nil, errors.New("invalid collection: " + collectionName)
+	if oid, err := primitive.ObjectIDFromHex(hex.EncodeToString(key[:])); err == nil {
+		return oid
+	} else {
+		log.WithError(err).Warn("failed to create object id")
+		return primitive.NewObjectID()
 	}
+}
 
-	if ctx == nil {
-		ctx, _ = context.WithTimeout(context.Background(), defaultOperationTimeout)
+func (storage *MongoStorage) NewRowID() RowID {
+	return primitive.NewObjectID()
+}
+
+// InsertOne and InsertMany
+
+type InsertOperation interface {
+	Context(ctx context.Context) InsertOperation
+	StopOnFail(stop bool) InsertOperation
+	One(document interface{}) (interface{}, error)
+	Many(documents []interface{}) ([]interface{}, error)
+}
+
+type MongoInsertOperation struct {
+	collection    *mongo.Collection
+	ctx           context.Context
+	optInsertMany *options.InsertManyOptions
+	err           error
+}
+
+func (fo MongoInsertOperation) Context(ctx context.Context) InsertOperation {
+	fo.ctx = ctx
+	return fo
+}
+
+func (fo MongoInsertOperation) StopOnFail(stop bool) InsertOperation {
+	fo.optInsertMany.SetOrdered(stop)
+	return fo
+}
+
+func (fo MongoInsertOperation) One(document interface{}) (interface{}, error) {
+	if fo.err != nil {
+		return nil, fo.err
 	}
 
-	result, err := collection.InsertOne(ctx, document)
+	result, err := fo.collection.InsertOne(fo.ctx, document)
 	if err != nil {
 		return nil, err
 	}
@@ -96,149 +127,211 @@ func (storage *MongoStorage) InsertOne(ctx context.Context, collectionName strin
 	return result.InsertedID, nil
 }
 
-func (storage *MongoStorage) InsertMany(ctx context.Context, collectionName string,
-	documents []interface{}) ([]interface{}, error) {
-
-	collection, ok :=
storage.collections[collectionName] - if !ok { - return nil, errors.New("invalid collection: " + collectionName) - } - - if ctx == nil { - ctx, _ = context.WithTimeout(context.Background(), defaultOperationTimeout) +func (fo MongoInsertOperation) Many(documents []interface{}) ([]interface{}, error) { + if fo.err != nil { + return nil, fo.err } - result, err := collection.InsertMany(ctx, documents) + results, err := fo.collection.InsertMany(fo.ctx, documents, fo.optInsertMany) if err != nil { return nil, err } - return result.InsertedIDs, nil + return results.InsertedIDs, nil } -func (storage *MongoStorage) UpdateOne(ctx context.Context, collectionName string, - filter interface{}, update interface {}, upsert bool) (interface{}, error) { - +func (storage *MongoStorage) Insert(collectionName string) InsertOperation { collection, ok := storage.collections[collectionName] + op := MongoInsertOperation{ + collection: collection, + optInsertMany: options.InsertMany(), + } if !ok { - return nil, errors.New("invalid collection: " + collectionName) + op.err = errors.New("invalid collection: " + collectionName) } + return op +} - if ctx == nil { - ctx, _ = context.WithTimeout(context.Background(), defaultOperationTimeout) - } +// UpdateOne and UpdateMany - opts := options.Update().SetUpsert(upsert) - update = bson.D{{"$set", update}} +type UpdateOperation interface { + Context(ctx context.Context) UpdateOperation + Filter(filter OrderedDocument) UpdateOperation + Upsert(upsertResults *interface{}) UpdateOperation + One(update interface{}) (bool, error) + Many(update interface{}) (int64, error) +} - result, err := collection.UpdateOne(ctx, filter, update, opts) - if err != nil { - return nil, err - } +type MongoUpdateOperation struct { + collection *mongo.Collection + filter OrderedDocument + update OrderedDocument + ctx context.Context + opt *options.UpdateOptions + upsertResult *interface{} + err error +} - if upsert { - return result.UpsertedID, nil - } +func (fo MongoUpdateOperation) Context(ctx context.Context) UpdateOperation { + fo.ctx = ctx + return fo +} - return result.ModifiedCount == 1, nil +func (fo MongoUpdateOperation) Filter(filter OrderedDocument) UpdateOperation { + fo.filter = filter + return fo } -func (storage *MongoStorage) UpdateMany(ctx context.Context, collectionName string, - filter interface{}, update interface {}, upsert bool) (interface{}, error) { +func (fo MongoUpdateOperation) Upsert(upsertResults *interface{}) UpdateOperation { + fo.upsertResult = upsertResults + fo.opt.SetUpsert(true) + return fo +} - collection, ok := storage.collections[collectionName] - if !ok { - return nil, errors.New("invalid collection: " + collectionName) +func (fo MongoUpdateOperation) One(update interface{}) (bool, error) { + if fo.err != nil { + return false, fo.err } - if ctx == nil { - ctx, _ = context.WithTimeout(context.Background(), defaultOperationTimeout) + for i := range fo.update { + fo.update[i].Value = update + } + result, err := fo.collection.UpdateOne(fo.ctx, fo.filter, fo.update, fo.opt) + if err != nil { + return false, err } - opts := options.Update().SetUpsert(upsert) - update = bson.D{{"$set", update}} + if fo.upsertResult != nil { + *(fo.upsertResult) = result.UpsertedID + } + return result.ModifiedCount == 1, nil +} + +func (fo MongoUpdateOperation) Many(update interface{}) (int64, error) { + if fo.err != nil { + return 0, fo.err + } - result, err := collection.UpdateMany(ctx, filter, update, opts) + for i := range fo.update { + fo.update[i].Value = update + } + result, 
err := fo.collection.UpdateMany(fo.ctx, fo.filter, fo.update, fo.opt) if err != nil { - return nil, err + return 0, err } + if fo.upsertResult != nil { + *(fo.upsertResult) = result.UpsertedID + } return result.ModifiedCount, nil } -func (storage *MongoStorage) FindOne(ctx context.Context, collectionName string, - filter interface{}) (UnorderedDocument, error) { - +func (storage *MongoStorage) Update(collectionName string) UpdateOperation { collection, ok := storage.collections[collectionName] + op := MongoUpdateOperation{ + collection: collection, + filter: OrderedDocument{}, + update: OrderedDocument{{"$set", nil}}, + opt: options.Update(), + } if !ok { - return nil, errors.New("invalid collection: " + collectionName) + op.err = errors.New("invalid collection: " + collectionName) } + return op +} - if ctx == nil { - ctx, _ = context.WithTimeout(context.Background(), defaultOperationTimeout) - } +// Find and FindOne - var result bson.M - err := collection.FindOne(ctx, filter).Decode(&result) - if err != nil { - if err == mongo.ErrNoDocuments { - return nil, nil - } +type FindOperation interface { + Context(ctx context.Context) FindOperation + Filter(filter OrderedDocument) FindOperation + Sort(field string, ascending bool) FindOperation + Limit(n int64) FindOperation + First(result interface{}) error + All(results interface{}) error +} - return nil, err - } +type MongoFindOperation struct { + collection *mongo.Collection + filter OrderedDocument + ctx context.Context + optFind *options.FindOptions + optFindOne *options.FindOneOptions + sorts []Entry + err error +} - return result, nil +func (fo MongoFindOperation) Context(ctx context.Context) FindOperation { + fo.ctx = ctx + return fo } -type FindOperation struct { - options options.FindOptions +func (fo MongoFindOperation) Filter(filter OrderedDocument) FindOperation { + fo.filter = filter + return fo } +func (fo MongoFindOperation) Limit(n int64) FindOperation { + fo.optFind.SetLimit(n) + return fo +} +func (fo MongoFindOperation) Sort(field string, ascending bool) FindOperation { + var sort int + if ascending { + sort = 1 + } else { + sort = -1 + } + fo.sorts = append(fo.sorts, primitive.E{Key: field, Value: sort}) + fo.optFind.SetSort(fo.sorts) + fo.optFindOne.SetSort(fo.sorts) + return fo +} -func (storage *MongoStorage) Find(ctx context.Context, collectionName string, - filter interface{}, results interface{}) error { +func (fo MongoFindOperation) First(result interface{}) error { + if fo.err != nil { + return fo.err + } - collection, ok := storage.collections[collectionName] - if !ok { - return errors.New("invalid collection: " + collectionName) + err := fo.collection.FindOne(fo.ctx, fo.filter, fo.optFindOne).Decode(result) + if err != nil { + if err == mongo.ErrNoDocuments { + result = nil + return nil + } + + return err } + return nil +} - if ctx == nil { - ctx, _ = context.WithTimeout(context.Background(), defaultOperationTimeout) - } - - options.FindOptions{ - AllowDiskUse: nil, - AllowPartialResults: nil, - BatchSize: nil, - Collation: nil, - Comment: nil, - CursorType: nil, - Hint: nil, - Limit: nil, - Max: nil, - MaxAwaitTime: nil, - MaxTime: nil, - Min: nil, - NoCursorTimeout: nil, - OplogReplay: nil, - Projection: nil, - ReturnKey: nil, - ShowRecordID: nil, - Skip: nil, - Snapshot: nil, - Sort: nil, - } - cursor, err := collection.Find(ctx, filter) +func (fo MongoFindOperation) All(results interface{}) error { + if fo.err != nil { + return fo.err + } + cursor, err := fo.collection.Find(fo.ctx, fo.filter, fo.optFind) if err 
!= nil { return err } - err = cursor.All(ctx, results) + err = cursor.All(fo.ctx, results) if err != nil { return err } - return nil } + +func (storage *MongoStorage) Find(collectionName string) FindOperation { + collection, ok := storage.collections[collectionName] + op := MongoFindOperation{ + collection: collection, + filter: OrderedDocument{}, + optFind: options.Find(), + optFindOne: options.FindOne(), + sorts: OrderedDocument{}, + } + if !ok { + op.err = errors.New("invalid collection: " + collectionName) + } + return op +} diff --git a/storage_test.go b/storage_test.go index 40440a4..e34bdb3 100644 --- a/storage_test.go +++ b/storage_test.go @@ -1,8 +1,6 @@ package main import ( - "context" - "errors" "github.com/stretchr/testify/assert" "go.mongodb.org/mongo-driver/bson/primitive" "testing" @@ -11,102 +9,177 @@ import ( type a struct { Id primitive.ObjectID `bson:"_id,omitempty"` - A string `bson:"a,omitempty"` - B int `bson:"b,omitempty"` - C time.Time `bson:"c,omitempty"` - D map[string]b `bson:"d"` - E []b `bson:"e,omitempty"` + A string `bson:"a,omitempty"` + B int `bson:"b,omitempty"` + C time.Time `bson:"c,omitempty"` + D map[string]b `bson:"d"` + E []b `bson:"e,omitempty"` } type b struct { A string `bson:"a,omitempty"` - B int `bson:"b,omitempty"` + B int `bson:"b,omitempty"` } -func testInsert(t *testing.T) { - // insert a document in an invalid connection - insertedId, err := storage.InsertOne(testContext, "invalid_collection", - OrderedDocument{{"key", "invalid"}}) - if insertedId != nil || err == nil { - t.Fatal("inserting documents in invalid collections must fail") - } +func TestOperationOnInvalidCollection(t *testing.T) { + wrapper := NewTestStorageWrapper(t) - // insert ordered document - beatriceId, err := storage.InsertOne(testContext, testCollection, - OrderedDocument{{"name", "Beatrice"}, {"description", "blablabla"}}) - if err != nil { - t.Fatal(err) - } - if beatriceId == nil { - t.Fatal("failed to insert an ordered document") - } + simpleDoc := UnorderedDocument{"key": "a", "value": 0} + insertOp := wrapper.Storage.Insert("invalid_collection").Context(wrapper.Context) + insertedId, err := insertOp.One(simpleDoc) + assert.Nil(t, insertedId) + assert.Error(t, err) - // insert unordered document - virgilioId, err := storage.InsertOne(testContext, testCollection, - UnorderedDocument{"name": "Virgilio", "description": "blablabla"}) - if err != nil { - t.Fatal(err) - } - if virgilioId == nil { - t.Fatal("failed to insert an unordered document") - } + insertedIds, err := insertOp.Many([]interface{}{simpleDoc}) + assert.Nil(t, insertedIds) + assert.Error(t, err) - // insert document with custom id - danteId := "000000" - insertedId, err = storage.InsertOne(testContext, testCollection, - UnorderedDocument{"_id": danteId, "name": "Dante Alighieri", "description": "blablabla"}) - if err != nil { - t.Fatal(err) - } - if insertedId != danteId { - t.Fatal("returned id doesn't match") - } + updateOp := wrapper.Storage.Update("invalid_collection").Context(wrapper.Context) + isUpdated, err := updateOp.One(simpleDoc) + assert.False(t, isUpdated) + assert.Error(t, err) - // insert duplicate document - insertedId, err = storage.InsertOne(testContext, testCollection, - UnorderedDocument{"_id": danteId, "name": "Dante Alighieri", "description": "blablabla"}) - if insertedId != nil || err == nil { - t.Fatal("inserting duplicate id must fail") - } + updated, err := updateOp.Many(simpleDoc) + assert.Zero(t, updated) + assert.Error(t, err) + + findOp := 
wrapper.Storage.Find("invalid_collection").Context(wrapper.Context)
+	var result interface{}
+	err = findOp.First(&result)
+	assert.Nil(t, result)
+	assert.Error(t, err)
+
+	var results interface{}
+	err = findOp.All(&results)
+	assert.Nil(t, results)
+	assert.Error(t, err)
+
+	wrapper.Destroy(t)
 }
 
-func testFindOne(t *testing.T) {
-	// find a document in an invalid connection
-	result, err := storage.FindOne(testContext, "invalid_collection",
-		OrderedDocument{{"key", "invalid"}})
-	if result != nil || err == nil {
-		t.Fatal("find a document in an invalid collections must fail")
-	}
+func TestSimpleInsertAndFind(t *testing.T) {
+	wrapper := NewTestStorageWrapper(t)
+	collectionName := "simple_insert_find"
+	wrapper.AddCollection(collectionName)
 
-	// find an existing document
-	result, err = storage.FindOne(testContext, testCollection, OrderedDocument{{"_id", "000000"}})
-	if err != nil {
-		t.Fatal(err)
-	}
-	if result == nil {
-		t.Fatal("FindOne cannot find the valid document")
-	}
-	name, ok := result["name"]
-	if !ok || name != "Dante Alighieri" {
-		t.Fatal("document retrieved with FindOne is invalid")
-	}
+	insertOp := wrapper.Storage.Insert(collectionName).Context(wrapper.Context)
+	simpleDocA := UnorderedDocument{"key": "a"}
+	idA, err := insertOp.One(simpleDocA)
+	assert.Len(t, idA, 12)
+	assert.Nil(t, err)
 
-	// find an existing document
-	result, err = storage.FindOne(testContext, testCollection, OrderedDocument{{"_id", "invalid_id"}})
-	if err != nil {
-		t.Fatal(err)
-	}
-	if result != nil {
-		t.Fatal("FindOne cannot find an invalid document")
+	simpleDocB := UnorderedDocument{"_id": "idb", "key": "b"}
+	idB, err := insertOp.One(simpleDocB)
+	assert.Equal(t, "idb", idB)
+	assert.Nil(t, err)
+
+	var result UnorderedDocument
+	findOp := wrapper.Storage.Find(collectionName).Context(wrapper.Context)
+	err = findOp.Filter(OrderedDocument{{"key", "a"}}).First(&result)
+	assert.Nil(t, err)
+	assert.Equal(t, idA, result["_id"])
+	assert.Equal(t, simpleDocA["key"], result["key"])
+
+	err = findOp.Filter(OrderedDocument{{"_id", idB}}).First(&result)
+	assert.Nil(t, err)
+	assert.Equal(t, idB, result["_id"])
+	assert.Equal(t, simpleDocB["key"], result["key"])
+
+	wrapper.Destroy(t)
+}
+
+func TestSimpleInsertManyAndFindMany(t *testing.T) {
+	wrapper := NewTestStorageWrapper(t)
+	collectionName := "simple_insert_many_find_many"
+	wrapper.AddCollection(collectionName)
+
+	insertOp := wrapper.Storage.Insert(collectionName).Context(wrapper.Context)
+	simpleDocs := []interface{}{
+		UnorderedDocument{"key": "a"},
+		UnorderedDocument{"_id": "idb", "key": "b"},
+		UnorderedDocument{"key": "c"},
 	}
+	ids, err := insertOp.Many(simpleDocs)
+	assert.Nil(t, err)
+	assert.Len(t, ids, 3)
+	assert.Equal(t, "idb", ids[1])
+
+	var results []UnorderedDocument
+	findOp := wrapper.Storage.Find(collectionName).Context(wrapper.Context)
+	err = findOp.Sort("key", false).All(&results) // test sort descending
+	assert.Nil(t, err)
+	assert.Len(t, results, 3)
+	assert.Equal(t, "c", results[0]["key"])
+	assert.Equal(t, "b", results[1]["key"])
+	assert.Equal(t, "a", results[2]["key"])
+
+	err = findOp.Sort("key", true).All(&results) // test sort ascending
+	assert.Nil(t, err)
+	assert.Len(t, results, 3)
+	assert.Equal(t, "c", results[2]["key"])
+	assert.Equal(t, "b", results[1]["key"])
+	assert.Equal(t, "a", results[0]["key"])
+
+	err = findOp.Filter(OrderedDocument{{"key", OrderedDocument{{"$gte", "b"}}}}).
+ Sort("key", true).All(&results) // test filter + assert.Nil(t, err) + assert.Len(t, results, 2) + assert.Equal(t, "b", results[0]["key"]) + assert.Equal(t, "c", results[1]["key"]) + + wrapper.Destroy(t) } -func TestBasicOperations(t *testing.T) { - t.Run("testInsert", testInsert) - t.Run("testFindOne", testFindOne) +func TestSimpleUpdateOneUpdateMany(t *testing.T) { + wrapper := NewTestStorageWrapper(t) + collectionName := "simple_update_one_update_many" + wrapper.AddCollection(collectionName) + + insertOp := wrapper.Storage.Insert(collectionName).Context(wrapper.Context) + simpleDocs := []interface{}{ + UnorderedDocument{"_id": "ida", "key": "a"}, + UnorderedDocument{"key": "b"}, + UnorderedDocument{"key": "c"}, + } + _, err := insertOp.Many(simpleDocs) + assert.Nil(t, err) + + updateOp := wrapper.Storage.Update(collectionName).Context(wrapper.Context) + isUpdated, err := updateOp.Filter(OrderedDocument{{"_id", "ida"}}). + One(OrderedDocument{{"key", "aa"}}) + assert.Nil(t, err) + assert.True(t, isUpdated) + + updated, err := updateOp.Filter(OrderedDocument{{"key", OrderedDocument{{"$gte", "b"}}}}). + Many(OrderedDocument{{"key", "bb"}}) + assert.Nil(t, err) + assert.Equal(t, int64(2), updated) + + var upsertId interface{} + isUpdated, err = updateOp.Upsert(&upsertId).Filter(OrderedDocument{{"key", "d"}}). + One(OrderedDocument{{"key", "d"}}) + assert.Nil(t, err) + assert.False(t, isUpdated) + assert.NotNil(t, upsertId) + + var results []UnorderedDocument + findOp := wrapper.Storage.Find(collectionName).Context(wrapper.Context) + err = findOp.Sort("key", true).All(&results) // test sort ascending + assert.Nil(t, err) + assert.Len(t, results, 4) + assert.Equal(t, "aa", results[0]["key"]) + assert.Equal(t, "bb", results[1]["key"]) + assert.Equal(t, "bb", results[2]["key"]) + assert.Equal(t, "d", results[3]["key"]) + + wrapper.Destroy(t) } -func TestInsertManyFindDocuments(t *testing.T) { +func TestComplexInsertManyFindMany(t *testing.T) { + wrapper := NewTestStorageWrapper(t) + collectionName := "complex_insert_many_find_many" + wrapper.AddCollection(collectionName) + testTime := time.Now() oid1, err := primitive.ObjectIDFromHex("ffffffffffffffffffffffff") assert.Nil(t, err) @@ -117,7 +190,7 @@ func TestInsertManyFindDocuments(t *testing.T) { B: 0, C: testTime, D: map[string]b{ - "first": {A: "0", B: 0}, + "first": {A: "0", B: 0}, "second": {A: "1", B: 1}, }, E: []b{ @@ -126,22 +199,22 @@ func TestInsertManyFindDocuments(t *testing.T) { }, a{ Id: oid1, - A: "test1", - B: 1, - C: testTime, - D: map[string]b{}, - E: []b{}, + A: "test1", + B: 1, + C: testTime, + D: map[string]b{}, + E: []b{}, }, a{}, } - ids, err := storage.InsertMany(testContext, testInsertManyFindCollection, docs) + ids, err := wrapper.Storage.Insert(collectionName).Context(wrapper.Context).Many(docs) assert.Nil(t, err) assert.Len(t, ids, 3) assert.Equal(t, ids[1], oid1) var results []a - err = storage.Find(testContext, testInsertManyFindCollection, NoFilters, &results) + err = wrapper.Storage.Find(collectionName).Context(wrapper.Context).All(&results) assert.Nil(t, err) assert.Len(t, results, 3) doc0, doc1, doc2 := docs[0].(a), docs[1].(a), docs[2].(a) @@ -163,57 +236,6 @@ func TestInsertManyFindDocuments(t *testing.T) { assert.Equal(t, doc0.E, results[0].E) assert.Nil(t, results[1].E) assert.Nil(t, results[2].E) -} -type testStorage struct { - insertFunc func(ctx context.Context, collectionName string, document interface{}) (interface{}, error) - insertManyFunc func(ctx context.Context, collectionName string, document 
[]interface{}) ([]interface{}, error) - updateOne func(ctx context.Context, collectionName string, filter interface{}, update interface {}, upsert bool) (interface{}, error) - updateMany func(ctx context.Context, collectionName string, filter interface{}, update interface {}, upsert bool) (interface{}, error) - findOne func(ctx context.Context, collectionName string, filter interface{}) (UnorderedDocument, error) - find func(ctx context.Context, collectionName string, filter interface{}, results interface{}) error -} - -func (ts testStorage) InsertOne(ctx context.Context, collectionName string, document interface{}) (interface{}, error) { - if ts.insertFunc != nil { - return ts.insertFunc(ctx, collectionName, document) - } - return nil, errors.New("not implemented") -} - -func (ts testStorage) InsertMany(ctx context.Context, collectionName string, document []interface{}) ([]interface{}, error) { - if ts.insertFunc != nil { - return ts.insertManyFunc(ctx, collectionName, document) - } - return nil, errors.New("not implemented") -} - -func (ts testStorage) UpdateOne(ctx context.Context, collectionName string, filter interface{}, update interface {}, - upsert bool) (interface{}, error) { - if ts.updateOne != nil { - return ts.updateOne(ctx, collectionName, filter, update, upsert) - } - return nil, errors.New("not implemented") -} - -func (ts testStorage) UpdateMany(ctx context.Context, collectionName string, filter interface{}, update interface {}, - upsert bool) (interface{}, error) { - if ts.updateOne != nil { - return ts.updateMany(ctx, collectionName, filter, update, upsert) - } - return nil, errors.New("not implemented") -} - -func (ts testStorage) FindOne(ctx context.Context, collectionName string, filter interface{}) (UnorderedDocument, error) { - if ts.findOne != nil { - return ts.findOne(ctx, collectionName, filter) - } - return nil, errors.New("not implemented") -} - -func (ts testStorage) Find(ctx context.Context, collectionName string, filter interface{}, results interface{}) error { - if ts.find != nil { - return ts.find(ctx, collectionName, filter, results) - } - return errors.New("not implemented") + wrapper.Destroy(t) } diff --git a/stream_handler.go b/stream_handler.go index ce580fc..3fafa21 100644 --- a/stream_handler.go +++ b/stream_handler.go @@ -2,11 +2,9 @@ package main import ( "bytes" - "encoding/binary" - "fmt" "github.com/flier/gohs/hyperscan" "github.com/google/gopacket/tcpassembly" - "log" + log "github.com/sirupsen/logrus" "time" ) @@ -28,30 +26,28 @@ type StreamHandler struct { currentIndex int firstPacketSeen time.Time lastPacketSeen time.Time - documentsKeys []string + documentsKeys []RowID streamLength int patternStream hyperscan.Stream patternMatches map[uint][]PatternSlice } -type PatternSlice [2]uint64 - // NewReaderStream returns a new StreamHandler object. 
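+// A StreamHandler receives the reassembled segments of one side of a TCP
+// stream through Reassembled, buffers the payload together with per-block
+// indexes, timestamps and loss flags, and persists them to the
+// connection_streams collection when the stream completes or the buffered
+// payload reaches MaxDocumentSize.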
func NewStreamHandler(connection ConnectionHandler, key StreamKey, scratch *hyperscan.Scratch) StreamHandler {
 	handler := StreamHandler{
-		connection: connection,
-		streamKey: key,
-		buffer: new(bytes.Buffer),
-		indexes: make([]int, 0, InitialBlockCount),
-		timestamps: make([]time.Time, 0, InitialBlockCount),
-		lossBlocks: make([]bool, 0, InitialBlockCount),
-		documentsKeys: make([]string, 0, 1), // most of the time the stream fit in one document
+		connection:     connection,
+		streamKey:      key,
+		buffer:         new(bytes.Buffer),
+		indexes:        make([]int, 0, InitialBlockCount),
+		timestamps:     make([]time.Time, 0, InitialBlockCount),
+		lossBlocks:     make([]bool, 0, InitialBlockCount),
+		documentsKeys:  make([]RowID, 0, 1), // most of the time the stream fits in one document
 		patternMatches: make(map[uint][]PatternSlice, 10), // TODO: change with exactly value
 	}
 
 	stream, err := connection.Patterns().Open(0, scratch, handler.onMatch, nil)
 	if err != nil {
-		log.Println("failed to create a stream: ", err)
+		log.WithField("streamKey", key).WithError(err).Error("failed to create a stream")
 	}
 	handler.patternStream = stream
@@ -81,7 +77,7 @@ func (sh *StreamHandler) Reassembled(reassembly []tcpassembly.Reassembly) {
 		}
 		n, err := sh.buffer.Write(r.Bytes[skip:])
 		if err != nil {
-			log.Println("error while copying bytes from Reassemble in stream_handler")
+			log.WithError(err).Error("failed to copy bytes from a Reassembly")
 			return
 		}
 		sh.indexes = append(sh.indexes, sh.currentIndex)
@@ -92,7 +88,7 @@ func (sh *StreamHandler) Reassembled(reassembly []tcpassembly.Reassembly) {
 
 		err = sh.patternStream.Scan(r.Bytes)
 		if err != nil {
-			log.Println("failed to scan packet buffer: ", err)
+			log.WithError(err).Error("failed to scan packet buffer")
 		}
 	}
 }
@@ -101,7 +97,7 @@ func (sh *StreamHandler) Reassembled(reassembly []tcpassembly.Reassembly) {
 func (sh *StreamHandler) ReassemblyComplete() {
 	err := sh.patternStream.Close()
 	if err != nil {
-		log.Println("failed to close pattern stream: ", err)
+		log.WithError(err).Error("failed to close pattern stream")
 	}
 
 	if sh.currentIndex > 0 {
@@ -144,33 +140,26 @@ func (sh *StreamHandler) onMatch(id uint, from uint64, to uint64, flags uint, co
 }
 
 func (sh *StreamHandler) storageCurrentDocument() {
-	streamKey := sh.generateDocumentKey()
-
-	_, err := sh.connection.Storage().InsertOne(sh.connection.Context(), "connection_streams", OrderedDocument{
-		{"_id", streamKey},
-		{"connection_id", nil},
-		{"document_index", len(sh.documentsKeys)},
-		{"payload", sh.buffer.Bytes()},
-		{"blocks_indexes", sh.indexes},
-		{"blocks_timestamps", sh.timestamps},
-		{"blocks_loss", sh.lossBlocks},
-		{"pattern_matches", sh.patternMatches},
-	})
+	// pack a hash of the four stream endpoints into the upper bytes of the
+	// id payload and the current document index into the lowest byte, so the
+	// documents of one stream get distinct but related ids
+	payload := (sh.streamKey[0].FastHash()^sh.streamKey[1].FastHash()^sh.streamKey[2].FastHash()^
+		sh.streamKey[3].FastHash())&uint64(0xffffffffffffff00) | uint64(len(sh.documentsKeys))
+	streamKey := sh.connection.Storage().NewCustomRowID(payload, sh.firstPacketSeen)
+
+	_, err := sh.connection.Storage().Insert(ConnectionStreams).
+		Context(sh.connection.Context()).
+ One(ConnectionStream{ + ID: streamKey, + ConnectionID: ZeroRowID, + DocumentIndex: len(sh.documentsKeys), + Payload: sh.buffer.Bytes(), + BlocksIndexes: sh.indexes, + BlocksTimestamps: sh.timestamps, + BlocksLoss: sh.lossBlocks, + PatternMatches: sh.patternMatches, + }) if err != nil { - log.Println("failed to insert connection stream: ", err) + log.WithError(err).Error("failed to insert connection stream") } sh.documentsKeys = append(sh.documentsKeys, streamKey) } - -func (sh *StreamHandler) generateDocumentKey() string { - hash := make([]byte, 16) - endpointsHash := sh.streamKey[0].FastHash() ^ sh.streamKey[1].FastHash() ^ - sh.streamKey[2].FastHash() ^ sh.streamKey[3].FastHash() - binary.BigEndian.PutUint64(hash, endpointsHash) - binary.BigEndian.PutUint64(hash[8:], uint64(sh.firstPacketSeen.UnixNano())) - binary.BigEndian.PutUint16(hash[8:], uint16(len(sh.documentsKeys))) - - return fmt.Sprintf("%x", hash) -} diff --git a/stream_handler_test.go b/stream_handler_test.go index cb5ecc7..ece3190 100644 --- a/stream_handler_test.go +++ b/stream_handler_test.go @@ -2,7 +2,6 @@ package main import ( "context" - "fmt" "github.com/flier/gohs/hyperscan" "github.com/google/gopacket/layers" "github.com/google/gopacket/tcpassembly" @@ -19,13 +18,14 @@ const testDstIp = "10.10.10.1" const srcPort = 44444 const dstPort = 8080 - func TestReassemblingEmptyStream(t *testing.T) { + wrapper := NewTestStorageWrapper(t) + wrapper.AddCollection(ConnectionStreams) patterns, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/nope/", 0)) - require.Nil(t, err) + require.NoError(t, err) scratch, err := hyperscan.NewScratch(patterns) - require.Nil(t, err) - streamHandler := createTestStreamHandler(testStorage{}, patterns, scratch) + require.NoError(t, err) + streamHandler := createTestStreamHandler(wrapper, patterns, scratch) streamHandler.Reassembled([]tcpassembly.Reassembly{{ Bytes: []byte{}, @@ -51,19 +51,20 @@ func TestReassemblingEmptyStream(t *testing.T) { assert.Equal(t, true, completed) err = scratch.Free() - require.Nil(t, err, "free scratch") + require.NoError(t, err, "free scratch") err = patterns.Close() - require.Nil(t, err, "close stream database") + require.NoError(t, err, "close stream database") + wrapper.Destroy(t) } - func TestReassemblingSingleDocument(t *testing.T) { + wrapper := NewTestStorageWrapper(t) + wrapper.AddCollection(ConnectionStreams) patterns, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/impossible_to_match/", 0)) - require.Nil(t, err) + require.NoError(t, err) scratch, err := hyperscan.NewScratch(patterns) - require.Nil(t, err) - storage := &testStorage{} - streamHandler := createTestStreamHandler(storage, patterns, scratch) + require.NoError(t, err) + streamHandler := createTestStreamHandler(wrapper, patterns, scratch) payloadLen := 256 firstTime := time.Unix(0, 0) @@ -71,10 +72,10 @@ func TestReassemblingSingleDocument(t *testing.T) { lastTime := time.Unix(20, 0) data := make([]byte, MaxDocumentSize) rand.Read(data) - reassembles := make([]tcpassembly.Reassembly, MaxDocumentSize / payloadLen) - indexes := make([]int, MaxDocumentSize / payloadLen) - timestamps := make([]time.Time, MaxDocumentSize / payloadLen) - lossBlocks := make([]bool, MaxDocumentSize / payloadLen) + reassembles := make([]tcpassembly.Reassembly, MaxDocumentSize/payloadLen) + indexes := make([]int, MaxDocumentSize/payloadLen) + timestamps := make([]time.Time, MaxDocumentSize/payloadLen) + lossBlocks := make([]bool, MaxDocumentSize/payloadLen) for i := 0; i < len(reassembles); i++ { var 
seen time.Time if i == 0 { @@ -86,36 +87,22 @@ func TestReassemblingSingleDocument(t *testing.T) { } reassembles[i] = tcpassembly.Reassembly{ - Bytes: data[i*payloadLen:(i+1)*payloadLen], + Bytes: data[i*payloadLen : (i+1)*payloadLen], Skip: 0, Start: i == 0, End: i == len(reassembles)-1, Seen: seen, } - indexes[i] = i*payloadLen + indexes[i] = i * payloadLen timestamps[i] = seen } - inserted := false - storage.insertFunc = func(ctx context.Context, collectionName string, document interface{}) (i interface{}, err error) { - od := document.(OrderedDocument) - assert.Equal(t, "connection_streams", collectionName) - assert.Equal(t, "bb41a60281cfae830000000000000000", od[0].Value) - assert.Equal(t, nil, od[1].Value) - assert.Equal(t, 0, od[2].Value) - assert.Equal(t, data, od[3].Value) - assert.Equal(t, indexes, od[4].Value) - assert.Equal(t, timestamps, od[5].Value) - assert.Equal(t, lossBlocks, od[6].Value) - assert.Len(t, od[7].Value, 0) - inserted = true - return nil, nil - } + var results []ConnectionStream streamHandler.Reassembled(reassembles) - if !assert.Equal(t, false, inserted) { - inserted = false - } + err = wrapper.Storage.Find(ConnectionStreams).Context(wrapper.Context).All(&results) + require.NoError(t, err) + assert.Len(t, results, 0) completed := false streamHandler.connection.(*testConnectionHandler).onComplete = func(handler *StreamHandler) { @@ -123,6 +110,18 @@ func TestReassemblingSingleDocument(t *testing.T) { } streamHandler.ReassemblyComplete() + err = wrapper.Storage.Find(ConnectionStreams).Context(wrapper.Context).All(&results) + require.NoError(t, err) + assert.Len(t, results, 1) + assert.Equal(t, firstTime.Unix(), results[0].ID.Timestamp().Unix()) + assert.Zero(t, results[0].ConnectionID) + assert.Equal(t, 0, results[0].DocumentIndex) + assert.Equal(t, data, results[0].Payload) + assert.Equal(t, indexes, results[0].BlocksIndexes) + assert.Len(t, results[0].BlocksTimestamps, len(timestamps)) // should be compared one by one + assert.Equal(t, lossBlocks, results[0].BlocksLoss) + assert.Len(t, results[0].PatternMatches, 0) + assert.Equal(t, len(data), streamHandler.currentIndex) assert.Equal(t, firstTime, streamHandler.firstPacketSeen) assert.Equal(t, lastTime, streamHandler.lastPacketSeen) @@ -130,35 +129,35 @@ func TestReassemblingSingleDocument(t *testing.T) { assert.Equal(t, len(data), streamHandler.streamLength) assert.Len(t, streamHandler.patternMatches, 0) - assert.Equal(t, true, inserted, "inserted") assert.Equal(t, true, completed, "completed") err = scratch.Free() - require.Nil(t, err, "free scratch") + require.NoError(t, err, "free scratch") err = patterns.Close() - require.Nil(t, err, "close stream database") + require.NoError(t, err, "close stream database") + wrapper.Destroy(t) } - func TestReassemblingMultipleDocuments(t *testing.T) { + wrapper := NewTestStorageWrapper(t) + wrapper.AddCollection(ConnectionStreams) patterns, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/impossible_to_match/", 0)) - require.Nil(t, err) + require.NoError(t, err) scratch, err := hyperscan.NewScratch(patterns) - require.Nil(t, err) - storage := &testStorage{} - streamHandler := createTestStreamHandler(storage, patterns, scratch) + require.NoError(t, err) + streamHandler := createTestStreamHandler(wrapper, patterns, scratch) payloadLen := 256 firstTime := time.Unix(0, 0) middleTime := time.Unix(10, 0) lastTime := time.Unix(20, 0) - dataSize := MaxDocumentSize*2 + dataSize := MaxDocumentSize * 2 data := make([]byte, dataSize) rand.Read(data) - reassembles := 
make([]tcpassembly.Reassembly, dataSize / payloadLen) - indexes := make([]int, dataSize / payloadLen) - timestamps := make([]time.Time, dataSize / payloadLen) - lossBlocks := make([]bool, dataSize / payloadLen) + reassembles := make([]tcpassembly.Reassembly, dataSize/payloadLen) + indexes := make([]int, dataSize/payloadLen) + timestamps := make([]time.Time, dataSize/payloadLen) + lossBlocks := make([]bool, dataSize/payloadLen) for i := 0; i < len(reassembles); i++ { var seen time.Time if i == 0 { @@ -170,38 +169,22 @@ func TestReassemblingMultipleDocuments(t *testing.T) { } reassembles[i] = tcpassembly.Reassembly{ - Bytes: data[i*payloadLen:(i+1)*payloadLen], + Bytes: data[i*payloadLen : (i+1)*payloadLen], Skip: 0, Start: i == 0, End: i == len(reassembles)-1, Seen: seen, } - indexes[i] = i*payloadLen % MaxDocumentSize + indexes[i] = i * payloadLen % MaxDocumentSize timestamps[i] = seen } - inserted := 0 - storage.insertFunc = func(ctx context.Context, collectionName string, document interface{}) (i interface{}, err error) { - od := document.(OrderedDocument) - blockLen := MaxDocumentSize / payloadLen - assert.Equal(t, "connection_streams", collectionName) - assert.Equal(t, fmt.Sprintf("bb41a60281cfae83000%v000000000000", inserted), od[0].Value) - assert.Equal(t, nil, od[1].Value) - assert.Equal(t, inserted, od[2].Value) - assert.Equal(t, data[MaxDocumentSize*inserted:MaxDocumentSize*(inserted+1)], od[3].Value) - assert.Equal(t, indexes[blockLen*inserted:blockLen*(inserted+1)], od[4].Value) - assert.Equal(t, timestamps[blockLen*inserted:blockLen*(inserted+1)], od[5].Value) - assert.Equal(t, lossBlocks[blockLen*inserted:blockLen*(inserted+1)], od[6].Value) - assert.Len(t, od[7].Value, 0) - inserted += 1 - - return nil, nil - } - streamHandler.Reassembled(reassembles) - if !assert.Equal(t, 1, inserted) { - inserted = 1 - } + + var results []ConnectionStream + err = wrapper.Storage.Find(ConnectionStreams).Context(wrapper.Context).All(&results) + require.NoError(t, err) + assert.Len(t, results, 1) completed := false streamHandler.connection.(*testConnectionHandler).onComplete = func(handler *StreamHandler) { @@ -209,6 +192,21 @@ func TestReassemblingMultipleDocuments(t *testing.T) { } streamHandler.ReassemblyComplete() + err = wrapper.Storage.Find(ConnectionStreams).Context(wrapper.Context).All(&results) + require.NoError(t, err) + assert.Len(t, results, 2) + for i := 0; i < 2; i++ { + blockLen := MaxDocumentSize / payloadLen + assert.Equal(t, firstTime.Unix(), results[i].ID.Timestamp().Unix()) + assert.Zero(t, results[i].ConnectionID) + assert.Equal(t, i, results[i].DocumentIndex) + assert.Equal(t, data[MaxDocumentSize*i:MaxDocumentSize*(i+1)], results[i].Payload) + assert.Equal(t, indexes[blockLen*i:blockLen*(i+1)], results[i].BlocksIndexes) + assert.Len(t, results[i].BlocksTimestamps, len(timestamps[blockLen*i:blockLen*(i+1)])) // should be compared one by one + assert.Equal(t, lossBlocks[blockLen*i:blockLen*(i+1)], results[i].BlocksLoss) + assert.Len(t, results[i].PatternMatches, 0) + } + assert.Equal(t, MaxDocumentSize, streamHandler.currentIndex) assert.Equal(t, firstTime, streamHandler.firstPacketSeen) assert.Equal(t, lastTime, streamHandler.lastPacketSeen) @@ -216,26 +214,28 @@ func TestReassemblingMultipleDocuments(t *testing.T) { assert.Equal(t, len(data), streamHandler.streamLength) assert.Len(t, streamHandler.patternMatches, 0) - assert.Equal(t, 2, inserted, "inserted") assert.Equal(t, true, completed, "completed") err = scratch.Free() - require.Nil(t, err, "free scratch") + 
require.NoError(t, err, "free scratch") err = patterns.Close() - require.Nil(t, err, "close stream database") + require.NoError(t, err, "close stream database") + wrapper.Destroy(t) } func TestReassemblingPatternMatching(t *testing.T) { + wrapper := NewTestStorageWrapper(t) + wrapper.AddCollection(ConnectionStreams) a, err := hyperscan.ParsePattern("/a{8}/i") - require.Nil(t, err) + require.NoError(t, err) a.Id = 0 a.Flags |= hyperscan.SomLeftMost b, err := hyperscan.ParsePattern("/b[c]+b/i") - require.Nil(t, err) + require.NoError(t, err) b.Id = 1 b.Flags |= hyperscan.SomLeftMost d, err := hyperscan.ParsePattern("/[d]+e[d]+/i") - require.Nil(t, err) + require.NoError(t, err) d.Id = 2 d.Flags |= hyperscan.SomLeftMost @@ -247,30 +247,12 @@ func TestReassemblingPatternMatching(t *testing.T) { } patterns, err := hyperscan.NewStreamDatabase(a, b, d) - require.Nil(t, err) + require.NoError(t, err) scratch, err := hyperscan.NewScratch(patterns) - require.Nil(t, err) - storage := &testStorage{} - streamHandler := createTestStreamHandler(storage, patterns, scratch) + require.NoError(t, err) + streamHandler := createTestStreamHandler(wrapper, patterns, scratch) seen := time.Unix(0, 0) - inserted := false - storage.insertFunc = func(ctx context.Context, collectionName string, document interface{}) (i interface{}, err error) { - od := document.(OrderedDocument) - assert.Equal(t, "connection_streams", collectionName) - assert.Equal(t, "bb41a60281cfae830000000000000000", od[0].Value) - assert.Equal(t, nil, od[1].Value) - assert.Equal(t, 0, od[2].Value) - assert.Equal(t, []byte(payload), od[3].Value) - assert.Equal(t, []int{0}, od[4].Value) - assert.Equal(t, []time.Time{seen}, od[5].Value) - assert.Equal(t, []bool{false}, od[6].Value) - assert.Equal(t, expected, od[7].Value) - inserted = true - - return nil, nil - } - streamHandler.Reassembled([]tcpassembly.Reassembly{{ Bytes: []byte(payload), Skip: 0, @@ -278,7 +260,11 @@ func TestReassemblingPatternMatching(t *testing.T) { End: true, Seen: seen, }}) - assert.Equal(t, false, inserted) + + var results []ConnectionStream + err = wrapper.Storage.Find(ConnectionStreams).Context(wrapper.Context).All(&results) + require.NoError(t, err) + assert.Len(t, results, 0) completed := false streamHandler.connection.(*testConnectionHandler).onComplete = func(handler *StreamHandler) { @@ -286,26 +272,36 @@ func TestReassemblingPatternMatching(t *testing.T) { } streamHandler.ReassemblyComplete() + err = wrapper.Storage.Find(ConnectionStreams).Context(wrapper.Context).All(&results) + require.NoError(t, err) + assert.Len(t, results, 1) + assert.Equal(t, seen.Unix(), results[0].ID.Timestamp().Unix()) + assert.Zero(t, results[0].ConnectionID) + assert.Equal(t, 0, results[0].DocumentIndex) + assert.Equal(t, []byte(payload), results[0].Payload) + assert.Equal(t, []int{0}, results[0].BlocksIndexes) + assert.Len(t, results[0].BlocksTimestamps, 1) // should be compared one by one + assert.Equal(t, []bool{false}, results[0].BlocksLoss) + assert.Equal(t, expected, results[0].PatternMatches) + assert.Equal(t, len(payload), streamHandler.currentIndex) assert.Equal(t, seen, streamHandler.firstPacketSeen) assert.Equal(t, seen, streamHandler.lastPacketSeen) assert.Len(t, streamHandler.documentsKeys, 1) assert.Equal(t, len(payload), streamHandler.streamLength) - assert.Equal(t, true, inserted, "inserted") assert.Equal(t, true, completed, "completed") err = scratch.Free() - require.Nil(t, err, "free scratch") + require.NoError(t, err, "free scratch") err = patterns.Close() - 
require.Nil(t, err, "close stream database") + require.NoError(t, err, "close stream database") + wrapper.Destroy(t) } - -func createTestStreamHandler(storage Storage, patterns hyperscan.StreamDatabase, scratch *hyperscan.Scratch) StreamHandler { +func createTestStreamHandler(wrapper *TestStorageWrapper, patterns hyperscan.StreamDatabase, scratch *hyperscan.Scratch) StreamHandler { testConnectionHandler := &testConnectionHandler{ - storage: storage, - context: context.Background(), + wrapper: wrapper, patterns: patterns, } @@ -318,18 +314,17 @@ func createTestStreamHandler(storage Storage, patterns hyperscan.StreamDatabase, } type testConnectionHandler struct { - storage Storage - context context.Context - patterns hyperscan.StreamDatabase + wrapper *TestStorageWrapper + patterns hyperscan.StreamDatabase onComplete func(*StreamHandler) } func (tch *testConnectionHandler) Storage() Storage { - return tch.storage + return tch.wrapper.Storage } func (tch *testConnectionHandler) Context() context.Context { - return tch.context + return tch.wrapper.Context } func (tch *testConnectionHandler) Patterns() hyperscan.StreamDatabase { -- cgit v1.2.3-70-g09d2