diff options
author | Emiliano Ciavatta | 2020-04-09 08:26:15 +0000 |
---|---|---|
committer | Emiliano Ciavatta | 2020-04-09 08:26:15 +0000 |
commit | 0520dab47d61e2c4de246459bf4f5c72d69182d3 (patch) | |
tree | d87df19c87a300d1022324f2ecad66380643d2f1 /storage.go | |
parent | 468690c60ee2e57ed2ccb4375e9ada5d2fed9473 (diff) |
Refactor storage
Diffstat (limited to 'storage.go')
-rw-r--r-- | storage.go | 339 |
1 files changed, 216 insertions, 123 deletions
@@ -6,60 +6,55 @@ import ( "encoding/hex" "errors" "fmt" - "time" - + log "github.com/sirupsen/logrus" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" + "time" ) +// Collections names const Connections = "connections" +const ConnectionStreams = "connection_streams" const ImportedPcaps = "imported_pcaps" const Rules = "rules" -var NoFilters = UnorderedDocument{} +const defaultConnectionTimeout = 10 * time.Second -const defaultConnectionTimeout = 10*time.Second -const defaultOperationTimeout = 3*time.Second +var ZeroRowID [12]byte type Storage interface { - InsertOne(ctx context.Context, collectionName string, document interface{}) (interface{}, error) - InsertMany(ctx context.Context, collectionName string, documents []interface{}) ([]interface{}, error) - UpdateOne(ctx context.Context, collectionName string, filter interface{}, update interface {}, upsert bool) (interface{}, error) - UpdateMany(ctx context.Context, collectionName string, filter interface{}, update interface {}, upsert bool) (interface{}, error) - FindOne(ctx context.Context, collectionName string, filter interface{}) (UnorderedDocument, error) - Find(ctx context.Context, collectionName string, filter interface{}, results interface{}) error + Insert(collectionName string) InsertOperation + Update(collectionName string) UpdateOperation + Find(collectionName string) FindOperation + NewCustomRowID(payload uint64, timestamp time.Time) RowID + NewRowID() RowID } type MongoStorage struct { - client *mongo.Client + client *mongo.Client collections map[string]*mongo.Collection } type OrderedDocument = bson.D type UnorderedDocument = bson.M - -func UniqueKey(timestamp time.Time, payload uint32) string { - var key [8]byte - binary.BigEndian.PutUint32(key[0:4], uint32(timestamp.Unix())) - binary.BigEndian.PutUint32(key[4:8], payload) - - return hex.EncodeToString(key[:]) -} +type Entry = bson.E +type RowID = primitive.ObjectID func NewMongoStorage(uri string, port int, database string) *MongoStorage { opt := options.Client() opt.ApplyURI(fmt.Sprintf("mongodb://%s:%v", uri, port)) client, err := mongo.NewClient(opt) if err != nil { - panic("Failed to create mongo client") + log.WithError(err).Panic("failed to create mongo client") } db := client.Database(database) colls := map[string]*mongo.Collection{ - Connections: db.Collection(Connections), + Connections: db.Collection(Connections), ImportedPcaps: db.Collection(ImportedPcaps), - Rules: db.Collection(Rules), + Rules: db.Collection(Rules), } return &MongoStorage{ @@ -76,19 +71,55 @@ func (storage *MongoStorage) Connect(ctx context.Context) error { return storage.client.Connect(ctx) } -func (storage *MongoStorage) InsertOne(ctx context.Context, collectionName string, - document interface{}) (interface{}, error) { +func (storage *MongoStorage) NewCustomRowID(payload uint64, timestamp time.Time) RowID { + var key [12]byte + binary.BigEndian.PutUint32(key[0:4], uint32(timestamp.Unix())) + binary.BigEndian.PutUint64(key[4:12], payload) - collection, ok := storage.collections[collectionName] - if !ok { - return nil, errors.New("invalid collection: " + collectionName) + if oid, err := primitive.ObjectIDFromHex(hex.EncodeToString(key[:])); err == nil { + return oid + } else { + log.WithError(err).Warn("failed to create object id") + return primitive.NewObjectID() } +} - if ctx == nil { - ctx, _ = context.WithTimeout(context.Background(), defaultOperationTimeout) +func (storage *MongoStorage) NewRowID() RowID { + return primitive.NewObjectID() +} + +// InsertOne and InsertMany + +type InsertOperation interface { + Context(ctx context.Context) InsertOperation + StopOnFail(stop bool) InsertOperation + One(document interface{}) (interface{}, error) + Many(documents []interface{}) ([]interface{}, error) +} + +type MongoInsertOperation struct { + collection *mongo.Collection + ctx context.Context + optInsertMany *options.InsertManyOptions + err error +} + +func (fo MongoInsertOperation) Context(ctx context.Context) InsertOperation { + fo.ctx = ctx + return fo +} + +func (fo MongoInsertOperation) StopOnFail(stop bool) InsertOperation { + fo.optInsertMany.SetOrdered(stop) + return fo +} + +func (fo MongoInsertOperation) One(document interface{}) (interface{}, error) { + if fo.err != nil { + return nil, fo.err } - result, err := collection.InsertOne(ctx, document) + result, err := fo.collection.InsertOne(fo.ctx, document) if err != nil { return nil, err } @@ -96,149 +127,211 @@ func (storage *MongoStorage) InsertOne(ctx context.Context, collectionName strin return result.InsertedID, nil } -func (storage *MongoStorage) InsertMany(ctx context.Context, collectionName string, - documents []interface{}) ([]interface{}, error) { - - collection, ok := storage.collections[collectionName] - if !ok { - return nil, errors.New("invalid collection: " + collectionName) - } - - if ctx == nil { - ctx, _ = context.WithTimeout(context.Background(), defaultOperationTimeout) +func (fo MongoInsertOperation) Many(documents []interface{}) ([]interface{}, error) { + if fo.err != nil { + return nil, fo.err } - result, err := collection.InsertMany(ctx, documents) + results, err := fo.collection.InsertMany(fo.ctx, documents, fo.optInsertMany) if err != nil { return nil, err } - return result.InsertedIDs, nil + return results.InsertedIDs, nil } -func (storage *MongoStorage) UpdateOne(ctx context.Context, collectionName string, - filter interface{}, update interface {}, upsert bool) (interface{}, error) { - +func (storage *MongoStorage) Insert(collectionName string) InsertOperation { collection, ok := storage.collections[collectionName] + op := MongoInsertOperation{ + collection: collection, + optInsertMany: options.InsertMany(), + } if !ok { - return nil, errors.New("invalid collection: " + collectionName) + op.err = errors.New("invalid collection: " + collectionName) } + return op +} - if ctx == nil { - ctx, _ = context.WithTimeout(context.Background(), defaultOperationTimeout) - } +// UpdateOne and UpdateMany - opts := options.Update().SetUpsert(upsert) - update = bson.D{{"$set", update}} +type UpdateOperation interface { + Context(ctx context.Context) UpdateOperation + Filter(filter OrderedDocument) UpdateOperation + Upsert(upsertResults *interface{}) UpdateOperation + One(update interface{}) (bool, error) + Many(update interface{}) (int64, error) +} - result, err := collection.UpdateOne(ctx, filter, update, opts) - if err != nil { - return nil, err - } +type MongoUpdateOperation struct { + collection *mongo.Collection + filter OrderedDocument + update OrderedDocument + ctx context.Context + opt *options.UpdateOptions + upsertResult *interface{} + err error +} - if upsert { - return result.UpsertedID, nil - } +func (fo MongoUpdateOperation) Context(ctx context.Context) UpdateOperation { + fo.ctx = ctx + return fo +} - return result.ModifiedCount == 1, nil +func (fo MongoUpdateOperation) Filter(filter OrderedDocument) UpdateOperation { + fo.filter = filter + return fo } -func (storage *MongoStorage) UpdateMany(ctx context.Context, collectionName string, - filter interface{}, update interface {}, upsert bool) (interface{}, error) { +func (fo MongoUpdateOperation) Upsert(upsertResults *interface{}) UpdateOperation { + fo.upsertResult = upsertResults + fo.opt.SetUpsert(true) + return fo +} - collection, ok := storage.collections[collectionName] - if !ok { - return nil, errors.New("invalid collection: " + collectionName) +func (fo MongoUpdateOperation) One(update interface{}) (bool, error) { + if fo.err != nil { + return false, fo.err } - if ctx == nil { - ctx, _ = context.WithTimeout(context.Background(), defaultOperationTimeout) + for i := range fo.update { + fo.update[i].Value = update + } + result, err := fo.collection.UpdateOne(fo.ctx, fo.filter, fo.update, fo.opt) + if err != nil { + return false, err } - opts := options.Update().SetUpsert(upsert) - update = bson.D{{"$set", update}} + if fo.upsertResult != nil { + *(fo.upsertResult) = result.UpsertedID + } + return result.ModifiedCount == 1, nil +} + +func (fo MongoUpdateOperation) Many(update interface{}) (int64, error) { + if fo.err != nil { + return 0, fo.err + } - result, err := collection.UpdateMany(ctx, filter, update, opts) + for i := range fo.update { + fo.update[i].Value = update + } + result, err := fo.collection.UpdateMany(fo.ctx, fo.filter, fo.update, fo.opt) if err != nil { - return nil, err + return 0, err } + if fo.upsertResult != nil { + *(fo.upsertResult) = result.UpsertedID + } return result.ModifiedCount, nil } -func (storage *MongoStorage) FindOne(ctx context.Context, collectionName string, - filter interface{}) (UnorderedDocument, error) { - +func (storage *MongoStorage) Update(collectionName string) UpdateOperation { collection, ok := storage.collections[collectionName] + op := MongoUpdateOperation{ + collection: collection, + filter: OrderedDocument{}, + update: OrderedDocument{{"$set", nil}}, + opt: options.Update(), + } if !ok { - return nil, errors.New("invalid collection: " + collectionName) + op.err = errors.New("invalid collection: " + collectionName) } + return op +} - if ctx == nil { - ctx, _ = context.WithTimeout(context.Background(), defaultOperationTimeout) - } +// Find and FindOne - var result bson.M - err := collection.FindOne(ctx, filter).Decode(&result) - if err != nil { - if err == mongo.ErrNoDocuments { - return nil, nil - } +type FindOperation interface { + Context(ctx context.Context) FindOperation + Filter(filter OrderedDocument) FindOperation + Sort(field string, ascending bool) FindOperation + Limit(n int64) FindOperation + First(result interface{}) error + All(results interface{}) error +} - return nil, err - } +type MongoFindOperation struct { + collection *mongo.Collection + filter OrderedDocument + ctx context.Context + optFind *options.FindOptions + optFindOne *options.FindOneOptions + sorts []Entry + err error +} - return result, nil +func (fo MongoFindOperation) Context(ctx context.Context) FindOperation { + fo.ctx = ctx + return fo } -type FindOperation struct { - options options.FindOptions +func (fo MongoFindOperation) Filter(filter OrderedDocument) FindOperation { + fo.filter = filter + return fo } +func (fo MongoFindOperation) Limit(n int64) FindOperation { + fo.optFind.SetLimit(n) + return fo +} +func (fo MongoFindOperation) Sort(field string, ascending bool) FindOperation { + var sort int + if ascending { + sort = 1 + } else { + sort = -1 + } + fo.sorts = append(fo.sorts, primitive.E{Key: field, Value: sort}) + fo.optFind.SetSort(fo.sorts) + fo.optFindOne.SetSort(fo.sorts) + return fo +} -func (storage *MongoStorage) Find(ctx context.Context, collectionName string, - filter interface{}, results interface{}) error { +func (fo MongoFindOperation) First(result interface{}) error { + if fo.err != nil { + return fo.err + } - collection, ok := storage.collections[collectionName] - if !ok { - return errors.New("invalid collection: " + collectionName) + err := fo.collection.FindOne(fo.ctx, fo.filter, fo.optFindOne).Decode(result) + if err != nil { + if err == mongo.ErrNoDocuments { + result = nil + return nil + } + + return err } + return nil +} - if ctx == nil { - ctx, _ = context.WithTimeout(context.Background(), defaultOperationTimeout) - } - - options.FindOptions{ - AllowDiskUse: nil, - AllowPartialResults: nil, - BatchSize: nil, - Collation: nil, - Comment: nil, - CursorType: nil, - Hint: nil, - Limit: nil, - Max: nil, - MaxAwaitTime: nil, - MaxTime: nil, - Min: nil, - NoCursorTimeout: nil, - OplogReplay: nil, - Projection: nil, - ReturnKey: nil, - ShowRecordID: nil, - Skip: nil, - Snapshot: nil, - Sort: nil, - } - cursor, err := collection.Find(ctx, filter) +func (fo MongoFindOperation) All(results interface{}) error { + if fo.err != nil { + return fo.err + } + cursor, err := fo.collection.Find(fo.ctx, fo.filter, fo.optFind) if err != nil { return err } - err = cursor.All(ctx, results) + err = cursor.All(fo.ctx, results) if err != nil { return err } - return nil } + +func (storage *MongoStorage) Find(collectionName string) FindOperation { + collection, ok := storage.collections[collectionName] + op := MongoFindOperation{ + collection: collection, + filter: OrderedDocument{}, + optFind: options.Find(), + optFindOne: options.FindOne(), + sorts: OrderedDocument{}, + } + if !ok { + op.err = errors.New("invalid collection: " + collectionName) + } + return op +} |