aboutsummaryrefslogtreecommitdiff
path: root/storage.go
diff options
context:
space:
mode:
authorEmiliano Ciavatta2020-04-09 08:26:15 +0000
committerEmiliano Ciavatta2020-04-09 08:26:15 +0000
commit0520dab47d61e2c4de246459bf4f5c72d69182d3 (patch)
treed87df19c87a300d1022324f2ecad66380643d2f1 /storage.go
parent468690c60ee2e57ed2ccb4375e9ada5d2fed9473 (diff)
Refactor storage
Diffstat (limited to 'storage.go')
-rw-r--r--storage.go339
1 files changed, 216 insertions, 123 deletions
diff --git a/storage.go b/storage.go
index 3be56d7..b88c5c8 100644
--- a/storage.go
+++ b/storage.go
@@ -6,60 +6,55 @@ import (
"encoding/hex"
"errors"
"fmt"
- "time"
-
+ log "github.com/sirupsen/logrus"
"go.mongodb.org/mongo-driver/bson"
+ "go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
+ "time"
)
+// Collections names
const Connections = "connections"
+const ConnectionStreams = "connection_streams"
const ImportedPcaps = "imported_pcaps"
const Rules = "rules"
-var NoFilters = UnorderedDocument{}
+const defaultConnectionTimeout = 10 * time.Second
-const defaultConnectionTimeout = 10*time.Second
-const defaultOperationTimeout = 3*time.Second
+var ZeroRowID [12]byte
type Storage interface {
- InsertOne(ctx context.Context, collectionName string, document interface{}) (interface{}, error)
- InsertMany(ctx context.Context, collectionName string, documents []interface{}) ([]interface{}, error)
- UpdateOne(ctx context.Context, collectionName string, filter interface{}, update interface {}, upsert bool) (interface{}, error)
- UpdateMany(ctx context.Context, collectionName string, filter interface{}, update interface {}, upsert bool) (interface{}, error)
- FindOne(ctx context.Context, collectionName string, filter interface{}) (UnorderedDocument, error)
- Find(ctx context.Context, collectionName string, filter interface{}, results interface{}) error
+ Insert(collectionName string) InsertOperation
+ Update(collectionName string) UpdateOperation
+ Find(collectionName string) FindOperation
+ NewCustomRowID(payload uint64, timestamp time.Time) RowID
+ NewRowID() RowID
}
type MongoStorage struct {
- client *mongo.Client
+ client *mongo.Client
collections map[string]*mongo.Collection
}
type OrderedDocument = bson.D
type UnorderedDocument = bson.M
-
-func UniqueKey(timestamp time.Time, payload uint32) string {
- var key [8]byte
- binary.BigEndian.PutUint32(key[0:4], uint32(timestamp.Unix()))
- binary.BigEndian.PutUint32(key[4:8], payload)
-
- return hex.EncodeToString(key[:])
-}
+type Entry = bson.E
+type RowID = primitive.ObjectID
func NewMongoStorage(uri string, port int, database string) *MongoStorage {
opt := options.Client()
opt.ApplyURI(fmt.Sprintf("mongodb://%s:%v", uri, port))
client, err := mongo.NewClient(opt)
if err != nil {
- panic("Failed to create mongo client")
+ log.WithError(err).Panic("failed to create mongo client")
}
db := client.Database(database)
colls := map[string]*mongo.Collection{
- Connections: db.Collection(Connections),
+ Connections: db.Collection(Connections),
ImportedPcaps: db.Collection(ImportedPcaps),
- Rules: db.Collection(Rules),
+ Rules: db.Collection(Rules),
}
return &MongoStorage{
@@ -76,19 +71,55 @@ func (storage *MongoStorage) Connect(ctx context.Context) error {
return storage.client.Connect(ctx)
}
-func (storage *MongoStorage) InsertOne(ctx context.Context, collectionName string,
- document interface{}) (interface{}, error) {
+func (storage *MongoStorage) NewCustomRowID(payload uint64, timestamp time.Time) RowID {
+ var key [12]byte
+ binary.BigEndian.PutUint32(key[0:4], uint32(timestamp.Unix()))
+ binary.BigEndian.PutUint64(key[4:12], payload)
- collection, ok := storage.collections[collectionName]
- if !ok {
- return nil, errors.New("invalid collection: " + collectionName)
+ if oid, err := primitive.ObjectIDFromHex(hex.EncodeToString(key[:])); err == nil {
+ return oid
+ } else {
+ log.WithError(err).Warn("failed to create object id")
+ return primitive.NewObjectID()
}
+}
- if ctx == nil {
- ctx, _ = context.WithTimeout(context.Background(), defaultOperationTimeout)
+func (storage *MongoStorage) NewRowID() RowID {
+ return primitive.NewObjectID()
+}
+
+// InsertOne and InsertMany
+
+type InsertOperation interface {
+ Context(ctx context.Context) InsertOperation
+ StopOnFail(stop bool) InsertOperation
+ One(document interface{}) (interface{}, error)
+ Many(documents []interface{}) ([]interface{}, error)
+}
+
+type MongoInsertOperation struct {
+ collection *mongo.Collection
+ ctx context.Context
+ optInsertMany *options.InsertManyOptions
+ err error
+}
+
+func (fo MongoInsertOperation) Context(ctx context.Context) InsertOperation {
+ fo.ctx = ctx
+ return fo
+}
+
+func (fo MongoInsertOperation) StopOnFail(stop bool) InsertOperation {
+ fo.optInsertMany.SetOrdered(stop)
+ return fo
+}
+
+func (fo MongoInsertOperation) One(document interface{}) (interface{}, error) {
+ if fo.err != nil {
+ return nil, fo.err
}
- result, err := collection.InsertOne(ctx, document)
+ result, err := fo.collection.InsertOne(fo.ctx, document)
if err != nil {
return nil, err
}
@@ -96,149 +127,211 @@ func (storage *MongoStorage) InsertOne(ctx context.Context, collectionName strin
return result.InsertedID, nil
}
-func (storage *MongoStorage) InsertMany(ctx context.Context, collectionName string,
- documents []interface{}) ([]interface{}, error) {
-
- collection, ok := storage.collections[collectionName]
- if !ok {
- return nil, errors.New("invalid collection: " + collectionName)
- }
-
- if ctx == nil {
- ctx, _ = context.WithTimeout(context.Background(), defaultOperationTimeout)
+func (fo MongoInsertOperation) Many(documents []interface{}) ([]interface{}, error) {
+ if fo.err != nil {
+ return nil, fo.err
}
- result, err := collection.InsertMany(ctx, documents)
+ results, err := fo.collection.InsertMany(fo.ctx, documents, fo.optInsertMany)
if err != nil {
return nil, err
}
- return result.InsertedIDs, nil
+ return results.InsertedIDs, nil
}
-func (storage *MongoStorage) UpdateOne(ctx context.Context, collectionName string,
- filter interface{}, update interface {}, upsert bool) (interface{}, error) {
-
+func (storage *MongoStorage) Insert(collectionName string) InsertOperation {
collection, ok := storage.collections[collectionName]
+ op := MongoInsertOperation{
+ collection: collection,
+ optInsertMany: options.InsertMany(),
+ }
if !ok {
- return nil, errors.New("invalid collection: " + collectionName)
+ op.err = errors.New("invalid collection: " + collectionName)
}
+ return op
+}
- if ctx == nil {
- ctx, _ = context.WithTimeout(context.Background(), defaultOperationTimeout)
- }
+// UpdateOne and UpdateMany
- opts := options.Update().SetUpsert(upsert)
- update = bson.D{{"$set", update}}
+type UpdateOperation interface {
+ Context(ctx context.Context) UpdateOperation
+ Filter(filter OrderedDocument) UpdateOperation
+ Upsert(upsertResults *interface{}) UpdateOperation
+ One(update interface{}) (bool, error)
+ Many(update interface{}) (int64, error)
+}
- result, err := collection.UpdateOne(ctx, filter, update, opts)
- if err != nil {
- return nil, err
- }
+type MongoUpdateOperation struct {
+ collection *mongo.Collection
+ filter OrderedDocument
+ update OrderedDocument
+ ctx context.Context
+ opt *options.UpdateOptions
+ upsertResult *interface{}
+ err error
+}
- if upsert {
- return result.UpsertedID, nil
- }
+func (fo MongoUpdateOperation) Context(ctx context.Context) UpdateOperation {
+ fo.ctx = ctx
+ return fo
+}
- return result.ModifiedCount == 1, nil
+func (fo MongoUpdateOperation) Filter(filter OrderedDocument) UpdateOperation {
+ fo.filter = filter
+ return fo
}
-func (storage *MongoStorage) UpdateMany(ctx context.Context, collectionName string,
- filter interface{}, update interface {}, upsert bool) (interface{}, error) {
+func (fo MongoUpdateOperation) Upsert(upsertResults *interface{}) UpdateOperation {
+ fo.upsertResult = upsertResults
+ fo.opt.SetUpsert(true)
+ return fo
+}
- collection, ok := storage.collections[collectionName]
- if !ok {
- return nil, errors.New("invalid collection: " + collectionName)
+func (fo MongoUpdateOperation) One(update interface{}) (bool, error) {
+ if fo.err != nil {
+ return false, fo.err
}
- if ctx == nil {
- ctx, _ = context.WithTimeout(context.Background(), defaultOperationTimeout)
+ for i := range fo.update {
+ fo.update[i].Value = update
+ }
+ result, err := fo.collection.UpdateOne(fo.ctx, fo.filter, fo.update, fo.opt)
+ if err != nil {
+ return false, err
}
- opts := options.Update().SetUpsert(upsert)
- update = bson.D{{"$set", update}}
+ if fo.upsertResult != nil {
+ *(fo.upsertResult) = result.UpsertedID
+ }
+ return result.ModifiedCount == 1, nil
+}
+
+func (fo MongoUpdateOperation) Many(update interface{}) (int64, error) {
+ if fo.err != nil {
+ return 0, fo.err
+ }
- result, err := collection.UpdateMany(ctx, filter, update, opts)
+ for i := range fo.update {
+ fo.update[i].Value = update
+ }
+ result, err := fo.collection.UpdateMany(fo.ctx, fo.filter, fo.update, fo.opt)
if err != nil {
- return nil, err
+ return 0, err
}
+ if fo.upsertResult != nil {
+ *(fo.upsertResult) = result.UpsertedID
+ }
return result.ModifiedCount, nil
}
-func (storage *MongoStorage) FindOne(ctx context.Context, collectionName string,
- filter interface{}) (UnorderedDocument, error) {
-
+func (storage *MongoStorage) Update(collectionName string) UpdateOperation {
collection, ok := storage.collections[collectionName]
+ op := MongoUpdateOperation{
+ collection: collection,
+ filter: OrderedDocument{},
+ update: OrderedDocument{{"$set", nil}},
+ opt: options.Update(),
+ }
if !ok {
- return nil, errors.New("invalid collection: " + collectionName)
+ op.err = errors.New("invalid collection: " + collectionName)
}
+ return op
+}
- if ctx == nil {
- ctx, _ = context.WithTimeout(context.Background(), defaultOperationTimeout)
- }
+// Find and FindOne
- var result bson.M
- err := collection.FindOne(ctx, filter).Decode(&result)
- if err != nil {
- if err == mongo.ErrNoDocuments {
- return nil, nil
- }
+type FindOperation interface {
+ Context(ctx context.Context) FindOperation
+ Filter(filter OrderedDocument) FindOperation
+ Sort(field string, ascending bool) FindOperation
+ Limit(n int64) FindOperation
+ First(result interface{}) error
+ All(results interface{}) error
+}
- return nil, err
- }
+type MongoFindOperation struct {
+ collection *mongo.Collection
+ filter OrderedDocument
+ ctx context.Context
+ optFind *options.FindOptions
+ optFindOne *options.FindOneOptions
+ sorts []Entry
+ err error
+}
- return result, nil
+func (fo MongoFindOperation) Context(ctx context.Context) FindOperation {
+ fo.ctx = ctx
+ return fo
}
-type FindOperation struct {
- options options.FindOptions
+func (fo MongoFindOperation) Filter(filter OrderedDocument) FindOperation {
+ fo.filter = filter
+ return fo
}
+func (fo MongoFindOperation) Limit(n int64) FindOperation {
+ fo.optFind.SetLimit(n)
+ return fo
+}
+func (fo MongoFindOperation) Sort(field string, ascending bool) FindOperation {
+ var sort int
+ if ascending {
+ sort = 1
+ } else {
+ sort = -1
+ }
+ fo.sorts = append(fo.sorts, primitive.E{Key: field, Value: sort})
+ fo.optFind.SetSort(fo.sorts)
+ fo.optFindOne.SetSort(fo.sorts)
+ return fo
+}
-func (storage *MongoStorage) Find(ctx context.Context, collectionName string,
- filter interface{}, results interface{}) error {
+func (fo MongoFindOperation) First(result interface{}) error {
+ if fo.err != nil {
+ return fo.err
+ }
- collection, ok := storage.collections[collectionName]
- if !ok {
- return errors.New("invalid collection: " + collectionName)
+ err := fo.collection.FindOne(fo.ctx, fo.filter, fo.optFindOne).Decode(result)
+ if err != nil {
+ if err == mongo.ErrNoDocuments {
+ result = nil
+ return nil
+ }
+
+ return err
}
+ return nil
+}
- if ctx == nil {
- ctx, _ = context.WithTimeout(context.Background(), defaultOperationTimeout)
- }
-
- options.FindOptions{
- AllowDiskUse: nil,
- AllowPartialResults: nil,
- BatchSize: nil,
- Collation: nil,
- Comment: nil,
- CursorType: nil,
- Hint: nil,
- Limit: nil,
- Max: nil,
- MaxAwaitTime: nil,
- MaxTime: nil,
- Min: nil,
- NoCursorTimeout: nil,
- OplogReplay: nil,
- Projection: nil,
- ReturnKey: nil,
- ShowRecordID: nil,
- Skip: nil,
- Snapshot: nil,
- Sort: nil,
- }
- cursor, err := collection.Find(ctx, filter)
+func (fo MongoFindOperation) All(results interface{}) error {
+ if fo.err != nil {
+ return fo.err
+ }
+ cursor, err := fo.collection.Find(fo.ctx, fo.filter, fo.optFind)
if err != nil {
return err
}
- err = cursor.All(ctx, results)
+ err = cursor.All(fo.ctx, results)
if err != nil {
return err
}
-
return nil
}
+
+func (storage *MongoStorage) Find(collectionName string) FindOperation {
+ collection, ok := storage.collections[collectionName]
+ op := MongoFindOperation{
+ collection: collection,
+ filter: OrderedDocument{},
+ optFind: options.Find(),
+ optFindOne: options.FindOne(),
+ sorts: OrderedDocument{},
+ }
+ if !ok {
+ op.err = errors.New("invalid collection: " + collectionName)
+ }
+ return op
+}