storagenode/notifications: db created (#3707)

This commit is contained in:
Vitalii Shpital 2019-12-16 19:59:01 +02:00 committed by GitHub
parent 11db709066
commit 53d9bc4530
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 623 additions and 115 deletions

View File

@ -16,6 +16,7 @@ import (
"github.com/zeebo/errs"
"go.uber.org/zap"
"storj.io/storj/private/dbutil"
"storj.io/storj/private/memory"
"storj.io/storj/satellite/attribution"
"storj.io/storj/satellite/satellitedb"
@ -77,7 +78,7 @@ func GenerateAttributionCSV(ctx context.Context, database string, partnerID uuid
}
func csvRowToStringSlice(p *attribution.CSVRow) ([]string, error) {
projectID, err := bytesToUUID(p.ProjectID)
projectID, err := dbutil.BytesToUUID(p.ProjectID)
if err != nil {
return nil, errs.New("Invalid Project ID")
}
@ -93,15 +94,3 @@ func csvRowToStringSlice(p *attribution.CSVRow) ([]string, error) {
}
return record, nil
}
// bytesToUUID is used to convert []byte to UUID
func bytesToUUID(data []byte) (uuid.UUID, error) {
var id uuid.UUID
copy(id[:], data)
if len(id) != len(data) {
return uuid.UUID{}, errs.New("Invalid uuid")
}
return id, nil
}

21
private/dbutil/uuid.go Normal file
View File

@ -0,0 +1,21 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package dbutil
import (
"github.com/skyrings/skyring-common/tools/uuid"
"github.com/zeebo/errs"
)
// BytesToUUID is used to convert []byte to UUID.
func BytesToUUID(data []byte) (uuid.UUID, error) {
var id uuid.UUID
copy(id[:], data)
if len(id) != len(data) {
return uuid.UUID{}, errs.New("Invalid uuid")
}
return id, nil
}

View File

@ -0,0 +1,31 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package dbutil
import (
"testing"
"github.com/stretchr/testify/assert"
"storj.io/storj/private/testrand"
)
func TestBytesToUUID(t *testing.T) {
t.Run("Invalid input", func(t *testing.T) {
str := "not UUID string"
bytes := []byte(str)
_, err := BytesToUUID(bytes)
assert.NotNil(t, err)
assert.Error(t, err)
})
t.Run("Valid input", func(t *testing.T) {
id := testrand.UUID()
result, err := BytesToUUID(id[:])
assert.NoError(t, err)
assert.Equal(t, result, id)
})
}

View File

@ -10,9 +10,9 @@ import (
"github.com/skyrings/skyring-common/tools/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zeebo/errs"
"storj.io/storj/pkg/pb"
"storj.io/storj/private/dbutil"
"storj.io/storj/private/testcontext"
"storj.io/storj/private/testrand"
"storj.io/storj/satellite"
@ -185,7 +185,7 @@ func verifyData(ctx *testcontext.Context, t *testing.T, attributionDB attributio
require.NotEqual(t, 0, len(results), "Results must not be empty.")
count := 0
for _, r := range results {
projectID, _ := bytesToUUID(r.ProjectID)
projectID, _ := dbutil.BytesToUUID(r.ProjectID)
// The query returns results by partnerID, so we need to filter out by projectID
if projectID != testData.projectID {
continue
@ -253,15 +253,3 @@ func createTallyData(ctx *testcontext.Context, projectAccoutingDB accounting.Pro
}
return tally, nil
}
// bytesToUUID is used to convert []byte to UUID
func bytesToUUID(data []byte) (uuid.UUID, error) {
var id uuid.UUID
copy(id[:], data)
if len(id) != len(data) {
return uuid.UUID{}, errs.New("Invalid uuid")
}
return id, nil
}

View File

@ -23,6 +23,7 @@ import (
"storj.io/storj/pkg/rpc/rpcstatus"
"storj.io/storj/pkg/signing"
"storj.io/storj/pkg/storj"
"storj.io/storj/private/dbutil"
"storj.io/storj/satellite/accounting"
"storj.io/storj/satellite/attribution"
"storj.io/storj/satellite/console"
@ -688,18 +689,6 @@ func (endpoint *Endpoint) SetAttributionOld(ctx context.Context, req *pb.SetAttr
return &pb.SetAttributionResponseOld{}, err
}
// bytesToUUID is used to convert []byte to UUID
func bytesToUUID(data []byte) (uuid.UUID, error) {
var id uuid.UUID
copy(id[:], data)
if len(id) != len(data) {
return uuid.UUID{}, errs.New("Invalid uuid")
}
return id, nil
}
// ProjectInfo returns allowed ProjectInfo for the provided API key
func (endpoint *Endpoint) ProjectInfo(ctx context.Context, req *pb.ProjectInfoRequest) (_ *pb.ProjectInfoResponse, err error) {
defer mon.Task()(&ctx)(&err)
@ -897,7 +886,7 @@ func (endpoint *Endpoint) SetBucketAttribution(ctx context.Context, req *pb.Buck
// returns empty uuid when neither is defined.
func (endpoint *Endpoint) resolvePartnerID(ctx context.Context, header *pb.RequestHeader, partnerIDBytes []byte) (uuid.UUID, error) {
if len(partnerIDBytes) > 0 {
partnerID, err := bytesToUUID(partnerIDBytes)
partnerID, err := dbutil.BytesToUUID(partnerIDBytes)
if err != nil {
return uuid.UUID{}, rpcstatus.Errorf(rpcstatus.InvalidArgument, "unable to parse partner ID: %v", err)
}

View File

@ -16,6 +16,7 @@ import (
"storj.io/storj/pkg/rpc"
"storj.io/storj/pkg/signing"
"storj.io/storj/pkg/storj"
"storj.io/storj/private/dbutil"
"storj.io/storj/satellite/console"
)
@ -87,7 +88,7 @@ func (service *Service) GetTokens(ctx context.Context, userID *uuid.UUID) (token
tokens = make([]uuid.UUID, len(tokensInBytes))
for i := range tokensInBytes {
token, err := bytesToUUID(tokensInBytes[i])
token, err := dbutil.BytesToUUID(tokensInBytes[i])
if err != nil {
service.log.Debug("failed to convert bytes to UUID", zap.Error(err))
continue
@ -185,15 +186,3 @@ func (service *Service) referralManagerConn(ctx context.Context) (*rpc.Conn, err
return service.dialer.DialAddressID(ctx, service.config.ReferralManagerURL.Address, service.config.ReferralManagerURL.ID)
}
// bytesToUUID is used to convert []byte to UUID
func bytesToUUID(data []byte) (uuid.UUID, error) {
var id uuid.UUID
copy(id[:], data)
if len(id) != len(data) {
return uuid.UUID{}, errs.New("Invalid uuid")
}
return id, nil
}

View File

@ -10,6 +10,7 @@ import (
"github.com/skyrings/skyring-common/tools/uuid"
"github.com/zeebo/errs"
"storj.io/storj/private/dbutil"
"storj.io/storj/satellite/console"
dbx "storj.io/storj/satellite/satellitedb/dbx"
)
@ -103,7 +104,7 @@ func (keys *apikeys) GetPagedByProjectID(ctx context.Context, projectID uuid.UUI
}
if partnerIDBytes != nil {
partnerID, err = bytesToUUID(partnerIDBytes)
partnerID, err = dbutil.BytesToUUID(partnerIDBytes)
if err != nil {
return nil, err
}
@ -219,12 +220,12 @@ func (keys *apikeys) Delete(ctx context.Context, id uuid.UUID) (err error) {
// fromDBXAPIKey converts dbx.ApiKey to satellite.APIKeyInfo
func fromDBXAPIKey(ctx context.Context, key *dbx.ApiKey) (_ *console.APIKeyInfo, err error) {
defer mon.Task()(&ctx)(&err)
id, err := bytesToUUID(key.Id)
id, err := dbutil.BytesToUUID(key.Id)
if err != nil {
return nil, err
}
projectID, err := bytesToUUID(key.ProjectId)
projectID, err := dbutil.BytesToUUID(key.ProjectId)
if err != nil {
return nil, err
}
@ -238,7 +239,7 @@ func fromDBXAPIKey(ctx context.Context, key *dbx.ApiKey) (_ *console.APIKeyInfo,
}
if key.PartnerId != nil {
result.PartnerID, err = bytesToUUID(key.PartnerId)
result.PartnerID, err = dbutil.BytesToUUID(key.PartnerId)
if err != nil {
return nil, err
}

View File

@ -11,6 +11,7 @@ import (
"github.com/skyrings/skyring-common/tools/uuid"
"github.com/zeebo/errs"
"storj.io/storj/private/dbutil"
"storj.io/storj/satellite/attribution"
dbx "storj.io/storj/satellite/satellitedb/dbx"
)
@ -169,11 +170,11 @@ func (keys *attributionDB) QueryAttribution(ctx context.Context, partnerID uuid.
}
func attributionFromDBX(info *dbx.ValueAttribution) (*attribution.Info, error) {
partnerID, err := bytesToUUID(info.PartnerId)
partnerID, err := dbutil.BytesToUUID(info.PartnerId)
if err != nil {
return nil, Error.Wrap(err)
}
projectID, err := bytesToUUID(info.ProjectId)
projectID, err := dbutil.BytesToUUID(info.ProjectId)
if err != nil {
return nil, Error.Wrap(err)
}

View File

@ -12,6 +12,7 @@ import (
"storj.io/storj/pkg/macaroon"
"storj.io/storj/pkg/storj"
"storj.io/storj/private/dbutil"
"storj.io/storj/satellite/metainfo"
dbx "storj.io/storj/satellite/satellitedb/dbx"
)
@ -188,11 +189,11 @@ func (db *bucketsDB) ListBuckets(ctx context.Context, projectID uuid.UUID, listO
}
func convertDBXtoBucket(dbxBucket *dbx.BucketMetainfo) (bucket storj.Bucket, err error) {
id, err := bytesToUUID(dbxBucket.Id)
id, err := dbutil.BytesToUUID(dbxBucket.Id)
if err != nil {
return bucket, storj.ErrBucket.Wrap(err)
}
project, err := bytesToUUID(dbxBucket.ProjectId)
project, err := dbutil.BytesToUUID(dbxBucket.ProjectId)
if err != nil {
return bucket, storj.ErrBucket.Wrap(err)
}
@ -219,7 +220,7 @@ func convertDBXtoBucket(dbxBucket *dbx.BucketMetainfo) (bucket storj.Bucket, err
}
if dbxBucket.PartnerId != nil {
partnerID, err := bytesToUUID(dbxBucket.PartnerId)
partnerID, err := dbutil.BytesToUUID(dbxBucket.PartnerId)
if err != nil {
return bucket, storj.ErrBucket.Wrap(err)
}

View File

@ -11,6 +11,7 @@ import (
"github.com/skyrings/skyring-common/tools/uuid"
"github.com/zeebo/errs"
"storj.io/storj/private/dbutil"
"storj.io/storj/satellite/payments/coinpayments"
"storj.io/storj/satellite/payments/stripecoinpayments"
dbx "storj.io/storj/satellite/satellitedb/dbx"
@ -265,7 +266,7 @@ func (db *coinPaymentsTransactions) ListUnapplied(ctx context.Context, offset in
return stripecoinpayments.TransactionsPage{}, err
}
userID, err := bytesToUUID(userIDB)
userID, err := dbutil.BytesToUUID(userIDB)
if err != nil {
return stripecoinpayments.TransactionsPage{}, errs.Wrap(err)
}
@ -307,7 +308,7 @@ func (db *coinPaymentsTransactions) ListUnapplied(ctx context.Context, offset in
// fromDBXCoinpaymentsTransaction converts *dbx.CoinpaymentsTransaction to *stripecoinpayments.Transaction.
func fromDBXCoinpaymentsTransaction(dbxCPTX *dbx.CoinpaymentsTransaction) (*stripecoinpayments.Transaction, error) {
userID, err := bytesToUUID(dbxCPTX.UserId)
userID, err := dbutil.BytesToUUID(dbxCPTX.UserId)
if err != nil {
return nil, errs.Wrap(err)
}

View File

@ -11,6 +11,7 @@ import (
"github.com/skyrings/skyring-common/tools/uuid"
"github.com/zeebo/errs"
"storj.io/storj/private/dbutil"
"storj.io/storj/satellite/payments"
"storj.io/storj/satellite/payments/coinpayments"
"storj.io/storj/satellite/payments/stripecoinpayments"
@ -129,17 +130,17 @@ func (coupons *coupons) ListPaged(ctx context.Context, offset int64, limit int,
// fromDBXCoupon converts *dbx.Coupon to *payments.Coupon.
func fromDBXCoupon(dbxCoupon *dbx.Coupon) (coupon payments.Coupon, err error) {
coupon.UserID, err = bytesToUUID(dbxCoupon.UserId)
coupon.UserID, err = dbutil.BytesToUUID(dbxCoupon.UserId)
if err != nil {
return payments.Coupon{}, err
}
coupon.ProjectID, err = bytesToUUID(dbxCoupon.ProjectId)
coupon.ProjectID, err = dbutil.BytesToUUID(dbxCoupon.ProjectId)
if err != nil {
return payments.Coupon{}, err
}
coupon.ID, err = bytesToUUID(dbxCoupon.Id)
coupon.ID, err = dbutil.BytesToUUID(dbxCoupon.Id)
if err != nil {
return payments.Coupon{}, err
}

View File

@ -10,6 +10,7 @@ import (
"github.com/skyrings/skyring-common/tools/uuid"
"storj.io/storj/private/dbutil"
"storj.io/storj/satellite/payments/stripecoinpayments"
dbx "storj.io/storj/satellite/satellitedb/dbx"
)
@ -89,7 +90,7 @@ func (customers *customers) List(ctx context.Context, offset int64, limit int, b
// fromDBXCustomer converts *dbx.StripeCustomer to *stripecoinpayments.Customer.
func fromDBXCustomer(dbxCustomer *dbx.StripeCustomer) (*stripecoinpayments.Customer, error) {
userID, err := bytesToUUID(dbxCustomer.UserId)
userID, err := dbutil.BytesToUUID(dbxCustomer.UserId)
if err != nil {
return nil, err
}

View File

@ -11,6 +11,7 @@ import (
"github.com/skyrings/skyring-common/tools/uuid"
"github.com/zeebo/errs"
"storj.io/storj/private/dbutil"
"storj.io/storj/satellite/payments/stripecoinpayments"
dbx "storj.io/storj/satellite/satellitedb/dbx"
)
@ -142,11 +143,11 @@ func (db *invoiceProjectRecords) ListUnapplied(ctx context.Context, offset int64
// fromDBXInvoiceProjectRecord converts *dbx.StripecoinpaymentsInvoiceProjectRecord to *stripecoinpayments.ProjectRecord
func fromDBXInvoiceProjectRecord(dbxRecord *dbx.StripecoinpaymentsInvoiceProjectRecord) (*stripecoinpayments.ProjectRecord, error) {
id, err := bytesToUUID(dbxRecord.Id)
id, err := dbutil.BytesToUUID(dbxRecord.Id)
if err != nil {
return nil, errs.Wrap(err)
}
projectID, err := bytesToUUID(dbxRecord.ProjectId)
projectID, err := dbutil.BytesToUUID(dbxRecord.ProjectId)
if err != nil {
return nil, errs.Wrap(err)
}

View File

@ -13,6 +13,7 @@ import (
"github.com/zeebo/errs"
"storj.io/storj/pkg/pb"
"storj.io/storj/private/dbutil"
"storj.io/storj/private/memory"
"storj.io/storj/satellite/accounting"
dbx "storj.io/storj/satellite/satellitedb/dbx"
@ -66,7 +67,7 @@ func (db *ProjectAccounting) GetTallies(ctx context.Context) (tallies []accounti
}
for _, dbxTally := range dbxTallies {
projectID, err := bytesToUUID(dbxTally.ProjectId)
projectID, err := dbutil.BytesToUUID(dbxTally.ProjectId)
if err != nil {
return nil, Error.Wrap(err)
}

View File

@ -10,6 +10,7 @@ import (
"github.com/skyrings/skyring-common/tools/uuid"
"github.com/zeebo/errs"
"storj.io/storj/private/dbutil"
"storj.io/storj/satellite/console"
dbx "storj.io/storj/satellite/satellitedb/dbx"
)
@ -124,12 +125,12 @@ func (pm *projectMembers) GetPagedByProjectID(ctx context.Context, projectID uui
return nil, err
}
memberID, err := bytesToUUID(memberIDBytes)
memberID, err := dbutil.BytesToUUID(memberIDBytes)
if err != nil {
return nil, err
}
projectID, err = bytesToUUID(projectIDBytes)
projectID, err = dbutil.BytesToUUID(projectIDBytes)
if err != nil {
return nil, err
}
@ -185,12 +186,12 @@ func projectMemberFromDBX(ctx context.Context, projectMember *dbx.ProjectMember)
return nil, errs.New("projectMember parameter is nil")
}
memberID, err := bytesToUUID(projectMember.MemberId)
memberID, err := dbutil.BytesToUUID(projectMember.MemberId)
if err != nil {
return nil, err
}
projectID, err := bytesToUUID(projectMember.ProjectId)
projectID, err := dbutil.BytesToUUID(projectMember.ProjectId)
if err != nil {
return nil, err
}

View File

@ -10,6 +10,7 @@ import (
"github.com/skyrings/skyring-common/tools/uuid"
"github.com/zeebo/errs"
"storj.io/storj/private/dbutil"
"storj.io/storj/satellite/console"
dbx "storj.io/storj/satellite/satellitedb/dbx"
)
@ -174,20 +175,20 @@ func projectFromDBX(ctx context.Context, project *dbx.Project) (_ *console.Proje
return nil, errs.New("project parameter is nil")
}
id, err := bytesToUUID(project.Id)
id, err := dbutil.BytesToUUID(project.Id)
if err != nil {
return nil, err
}
var partnerID uuid.UUID
if len(project.PartnerId) > 0 {
partnerID, err = bytesToUUID(project.PartnerId)
partnerID, err = dbutil.BytesToUUID(project.PartnerId)
if err != nil {
return nil, err
}
}
ownerID, err := bytesToUUID(project.OwnerId)
ownerID, err := dbutil.BytesToUUID(project.OwnerId)
if err != nil {
return nil, err
}

View File

@ -9,6 +9,7 @@ import (
"github.com/skyrings/skyring-common/tools/uuid"
"github.com/zeebo/errs"
"storj.io/storj/private/dbutil"
"storj.io/storj/satellite/console"
dbx "storj.io/storj/satellite/satellitedb/dbx"
)
@ -96,7 +97,7 @@ func registrationTokenFromDBX(ctx context.Context, regToken *dbx.RegistrationTok
}
if regToken.OwnerId != nil {
ownerID, err := bytesToUUID(regToken.OwnerId)
ownerID, err := dbutil.BytesToUUID(regToken.OwnerId)
if err != nil {
return nil, err
}

View File

@ -9,6 +9,7 @@ import (
"github.com/skyrings/skyring-common/tools/uuid"
"storj.io/storj/private/dbutil"
"storj.io/storj/satellite/console"
dbx "storj.io/storj/satellite/satellitedb/dbx"
)
@ -97,7 +98,7 @@ func resetPasswordTokenFromDBX(ctx context.Context, resetToken *dbx.ResetPasswor
}
if resetToken.OwnerId != nil {
ownerID, err := bytesToUUID(resetToken.OwnerId)
ownerID, err := dbutil.BytesToUUID(resetToken.OwnerId)
if err != nil {
return nil, err
}

View File

@ -10,6 +10,7 @@ import (
"github.com/skyrings/skyring-common/tools/uuid"
"github.com/zeebo/errs"
"storj.io/storj/private/dbutil"
"storj.io/storj/satellite/console"
dbx "storj.io/storj/satellite/satellitedb/dbx"
)
@ -122,7 +123,7 @@ func userFromDBX(ctx context.Context, user *dbx.User) (_ *console.User, err erro
return nil, errs.New("user parameter is nil")
}
id, err := bytesToUUID(user.Id)
id, err := dbutil.BytesToUUID(user.Id)
if err != nil {
return nil, err
}
@ -137,7 +138,7 @@ func userFromDBX(ctx context.Context, user *dbx.User) (_ *console.User, err erro
}
if user.PartnerId != nil {
result.PartnerID, err = bytesToUUID(user.PartnerId)
result.PartnerID, err = dbutil.BytesToUUID(user.PartnerId)
if err != nil {
return nil, err
}

View File

@ -7,23 +7,11 @@ import (
"database/sql/driver"
"github.com/skyrings/skyring-common/tools/uuid"
"github.com/zeebo/errs"
"storj.io/storj/pkg/storj"
"storj.io/storj/private/dbutil"
)
// bytesToUUID is used to convert []byte to UUID
func bytesToUUID(data []byte) (uuid.UUID, error) {
var id uuid.UUID
copy(id[:], data)
if len(id) != len(data) {
return uuid.UUID{}, errs.New("Invalid uuid")
}
return id, nil
}
type postgresNodeIDList storj.NodeIDList
// Value converts a NodeIDList to a postgres array
@ -74,11 +62,12 @@ type uuidScan struct {
// Scan is used to wrap logic of db scan with uuid conversion
func (s *uuidScan) Scan(src interface{}) (err error) {
b, ok := src.([]byte)
if !ok {
return Error.New("unexpected type %T for uuid", src)
}
*s.uuid, err = bytesToUUID(b)
*s.uuid, err = dbutil.BytesToUUID(b)
if err != nil {
return Error.Wrap(err)
}

View File

@ -14,25 +14,6 @@ import (
"storj.io/storj/private/testrand"
)
func TestBytesToUUID(t *testing.T) {
t.Run("Invalid input", func(t *testing.T) {
str := "not UUID string"
bytes := []byte(str)
_, err := bytesToUUID(bytes)
assert.NotNil(t, err)
assert.Error(t, err)
})
t.Run("Valid input", func(t *testing.T) {
id := testrand.UUID()
result, err := bytesToUUID(id[:])
assert.NoError(t, err)
assert.Equal(t, result, id)
})
}
func TestPostgresNodeIDsArray(t *testing.T) {
ids := make(storj.NodeIDList, 10)
for i := range ids {

View File

@ -0,0 +1,74 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package notifications
import (
"context"
"time"
"github.com/skyrings/skyring-common/tools/uuid"
"storj.io/storj/pkg/storj"
)
// DB tells how application works with notifications database.
//
// architecture: Database
type DB interface {
Insert(ctx context.Context, notification NewNotification) (Notification, error)
List(ctx context.Context, cursor NotificationCursor) (NotificationPage, error)
Read(ctx context.Context, notificationID uuid.UUID) error
ReadAll(ctx context.Context) error
}
// NotificationType is a numeric value of specific notification type.
type NotificationType int
const (
// NotificationTypeCustom is a common notification type which doesn't describe node's core functionality.
// TODO: change type name when all notification types will be known
NotificationTypeCustom NotificationType = 0
// NotificationTypeAuditCheckFailure is a notification type which describes node's audit check failure.
NotificationTypeAuditCheckFailure NotificationType = 1
// NotificationTypeUptimeCheckFailure is a notification type which describes node's uptime check failure.
NotificationTypeUptimeCheckFailure NotificationType = 2
// NotificationTypeDisqualification is a notification type which describes node's disqualification status.
NotificationTypeDisqualification NotificationType = 3
)
// NewNotification holds notification entity info which is being received from satellite or local client.
type NewNotification struct {
SenderID storj.NodeID
Type NotificationType
Title string
Message string
}
// Notification holds notification entity info which is being retrieved from database.
type Notification struct {
ID uuid.UUID
SenderID storj.NodeID
Type NotificationType
Title string
Message string
ReadAt *time.Time
CreatedAt time.Time
}
// NotificationCursor holds notification cursor entity which is used to create listed page from database.
type NotificationCursor struct {
Limit uint
Page uint
}
// NotificationPage holds notification page entity which is used to show listed page of notifications on UI.
type NotificationPage struct {
Notifications []Notification
Offset uint64
Limit uint
CurrentPage uint
PageCount uint
TotalCount uint64
}

View File

@ -0,0 +1,167 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package notifications_test
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"storj.io/storj/pkg/storj"
"storj.io/storj/private/testcontext"
"storj.io/storj/private/testidentity"
"storj.io/storj/private/testrand"
"storj.io/storj/storagenode"
"storj.io/storj/storagenode/notifications"
"storj.io/storj/storagenode/storagenodedb/storagenodedbtest"
)
func TestNotificationsDB(t *testing.T) {
storagenodedbtest.Run(t, func(t *testing.T, db storagenode.DB) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
notificationsdb := db.Notifications()
satellite0 := testidentity.MustPregeneratedSignedIdentity(0, storj.LatestIDVersion()).ID
satellite1 := testidentity.MustPregeneratedSignedIdentity(1, storj.LatestIDVersion()).ID
satellite2 := testidentity.MustPregeneratedSignedIdentity(2, storj.LatestIDVersion()).ID
expectedNotification0 := notifications.NewNotification{
SenderID: satellite0,
Type: 0,
Title: "testTitle0",
Message: "testMessage0",
}
expectedNotification1 := notifications.NewNotification{
SenderID: satellite1,
Type: 1,
Title: "testTitle1",
Message: "testMessage1",
}
expectedNotification2 := notifications.NewNotification{
SenderID: satellite2,
Type: 2,
Title: "testTitle2",
Message: "testMessage2",
}
notificationCursor := notifications.NotificationCursor{
Limit: 2,
Page: 1,
}
notificationFromDB0, err := notificationsdb.Insert(ctx, expectedNotification0)
assert.NoError(t, err)
assert.Equal(t, expectedNotification0.SenderID, notificationFromDB0.SenderID)
assert.Equal(t, expectedNotification0.Type, notificationFromDB0.Type)
assert.Equal(t, expectedNotification0.Title, notificationFromDB0.Title)
assert.Equal(t, expectedNotification0.Message, notificationFromDB0.Message)
notificationFromDB1, err := notificationsdb.Insert(ctx, expectedNotification1)
assert.NoError(t, err)
assert.Equal(t, expectedNotification1.SenderID, notificationFromDB1.SenderID)
assert.Equal(t, expectedNotification1.Type, notificationFromDB1.Type)
assert.Equal(t, expectedNotification1.Title, notificationFromDB1.Title)
assert.Equal(t, expectedNotification1.Message, notificationFromDB1.Message)
notificationFromDB2, err := notificationsdb.Insert(ctx, expectedNotification2)
assert.NoError(t, err)
assert.Equal(t, expectedNotification2.SenderID, notificationFromDB2.SenderID)
assert.Equal(t, expectedNotification2.Type, notificationFromDB2.Type)
assert.Equal(t, expectedNotification2.Title, notificationFromDB2.Title)
assert.Equal(t, expectedNotification2.Message, notificationFromDB2.Message)
page := notifications.NotificationPage{}
// test List method to return right form of page depending on cursor.
t.Run("test paged list", func(t *testing.T) {
page, err = notificationsdb.List(ctx, notificationCursor)
assert.NoError(t, err)
assert.Equal(t, 2, len(page.Notifications))
assert.Equal(t, notificationFromDB0, page.Notifications[0])
assert.Equal(t, notificationFromDB1, page.Notifications[1])
assert.Equal(t, notificationCursor.Limit, page.Limit)
assert.Equal(t, uint64(0), page.Offset)
assert.Equal(t, uint(2), page.PageCount)
assert.Equal(t, uint64(3), page.TotalCount)
assert.Equal(t, uint(1), page.CurrentPage)
})
notificationCursor = notifications.NotificationCursor{
Limit: 5,
Page: 1,
}
// test Read method to make specific notification's status as read.
t.Run("test notification read", func(t *testing.T) {
err = notificationsdb.Read(ctx, notificationFromDB0.ID)
assert.NoError(t, err)
page, err = notificationsdb.List(ctx, notificationCursor)
assert.NoError(t, err)
assert.NotEqual(t, page.Notifications[0].ReadAt, (*time.Time)(nil))
err = notificationsdb.Read(ctx, notificationFromDB1.ID)
assert.NoError(t, err)
page, err = notificationsdb.List(ctx, notificationCursor)
assert.NoError(t, err)
assert.NotEqual(t, page.Notifications[1].ReadAt, (*time.Time)(nil))
assert.Equal(t, page.Notifications[2].ReadAt, (*time.Time)(nil))
})
// test ReadAll method to make all notifications' status as read.
t.Run("test notification read all", func(t *testing.T) {
err = notificationsdb.ReadAll(ctx)
assert.NoError(t, err)
page, err = notificationsdb.List(ctx, notificationCursor)
assert.NoError(t, err)
assert.NotEqual(t, page.Notifications[0].ReadAt, (*time.Time)(nil))
assert.NotEqual(t, page.Notifications[1].ReadAt, (*time.Time)(nil))
assert.NotEqual(t, page.Notifications[2].ReadAt, (*time.Time)(nil))
})
})
}
func TestEmptyNotificationsDB(t *testing.T) {
storagenodedbtest.Run(t, func(t *testing.T, db storagenode.DB) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
notificationsdb := db.Notifications()
notificationCursor := notifications.NotificationCursor{
Limit: 5,
Page: 1,
}
// test List method to return right form of page depending on cursor with empty database.
t.Run("test empty paged list", func(t *testing.T) {
page, err := notificationsdb.List(ctx, notificationCursor)
assert.NoError(t, err)
assert.Equal(t, len(page.Notifications), 0)
assert.Equal(t, page.Limit, notificationCursor.Limit)
assert.Equal(t, page.Offset, uint64(0))
assert.Equal(t, page.PageCount, uint(0))
assert.Equal(t, page.TotalCount, uint64(0))
assert.Equal(t, page.CurrentPage, uint(0))
})
// test notification read with not existing id.
t.Run("test notification read with not existing id", func(t *testing.T) {
err := notificationsdb.Read(ctx, testrand.UUID())
assert.Error(t, err, "no rows affected")
})
// test read for all notifications if they don't exist.
t.Run("test notification readAll on empty page", func(t *testing.T) {
err := notificationsdb.ReadAll(ctx)
assert.NoError(t, err)
})
})
}

View File

@ -37,6 +37,7 @@ import (
"storj.io/storj/storagenode/inspector"
"storj.io/storj/storagenode/monitor"
"storj.io/storj/storagenode/nodestats"
"storj.io/storj/storagenode/notifications"
"storj.io/storj/storagenode/orders"
"storj.io/storj/storagenode/pieces"
"storj.io/storj/storagenode/piecestore"
@ -71,6 +72,7 @@ type DB interface {
Reputation() reputation.DB
StorageUsage() storageusage.DB
Satellites() satellites.DB
Notifications() notifications.DB
}
// Config is all the configuration parameters for a Storage Node

View File

@ -21,6 +21,7 @@ import (
"storj.io/storj/storage/filestore"
"storj.io/storj/storagenode"
"storj.io/storj/storagenode/bandwidth"
"storj.io/storj/storagenode/notifications"
"storj.io/storj/storagenode/orders"
"storj.io/storj/storagenode/pieces"
"storj.io/storj/storagenode/piecestore"
@ -37,6 +38,8 @@ var (
// ErrDatabase represents errors from the databases.
ErrDatabase = errs.Class("storage node database error")
// ErrNoRows represents database error if rows weren't affected.
ErrNoRows = errs.New("no rows affected")
)
var _ storagenode.DB = (*DB)(nil)
@ -112,6 +115,7 @@ type DB struct {
storageUsageDB *storageUsageDB
usedSerialsDB *usedSerialsDB
satellitesDB *satellitesDB
notificationsDB *notificationDB
SQLDBs map[string]DBContainer
}
@ -134,6 +138,7 @@ func New(log *zap.Logger, config Config) (*DB, error) {
storageUsageDB := &storageUsageDB{}
usedSerialsDB := &usedSerialsDB{}
satellitesDB := &satellitesDB{}
notificationsDB := &notificationDB{}
db := &DB{
log: log,
@ -153,6 +158,7 @@ func New(log *zap.Logger, config Config) (*DB, error) {
storageUsageDB: storageUsageDB,
usedSerialsDB: usedSerialsDB,
satellitesDB: satellitesDB,
notificationsDB: notificationsDB,
SQLDBs: map[string]DBContainer{
DeprecatedInfoDBName: deprecatedInfoDB,
@ -165,6 +171,7 @@ func New(log *zap.Logger, config Config) (*DB, error) {
StorageUsageDBName: storageUsageDB,
UsedSerialsDBName: usedSerialsDB,
SatellitesDBName: satellitesDB,
NotificationsDBName: notificationsDB,
},
}
@ -230,6 +237,11 @@ func (db *DB) openDatabases() error {
if err != nil {
return errs.Combine(err, db.closeDatabases())
}
err = db.openDatabase(NotificationsDBName)
if err != nil {
return errs.Combine(err, db.closeDatabases())
}
return nil
}
@ -351,6 +363,11 @@ func (db *DB) Satellites() satellites.DB {
return db.satellitesDB
}
// Notifications returns the instance of the Notifications database.
func (db *DB) Notifications() notifications.DB {
return db.notificationsDB
}
// RawDatabases are required for testing purposes
func (db *DB) RawDatabases() map[string]DBContainer {
return db.SQLDBs
@ -920,6 +937,23 @@ func (db *DB) Migration(ctx context.Context) *migrate.Migration {
`CREATE INDEX idx_order_archived_at ON order_archive_(archived_at)`,
},
},
{
DB: db.notificationsDB,
Description: "Create notifications table",
Version: 28,
Action: migrate.SQL{
`CREATE TABLE notifications (
id BLOB NOT NULL,
sender_id BLOB NOT NULL,
type INTEGER NOT NULL,
title TEXT NOT NULL,
message TEXT NOT NULL,
read_at TIMESTAMP,
created_at TIMESTAMP NOT NULL,
PRIMARY KEY (id)
);`,
},
},
},
}
}

View File

@ -0,0 +1,201 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package storagenodedb
import (
"context"
"time"
"github.com/skyrings/skyring-common/tools/uuid"
"github.com/zeebo/errs"
"storj.io/storj/private/dbutil"
"storj.io/storj/storagenode/notifications"
)
// ensures that notificationDB implements notifications.Notifications interface.
var _ notifications.DB = (*notificationDB)(nil)
// NotificationsDBName represents the database name.
const NotificationsDBName = "notifications"
// ErrNotificationsDB represents errors from the notifications database.
var ErrNotificationsDB = errs.Class("notificationsDB error")
// notificationDB is an implementation of notifications.Notifications.
//
// architecture: Database
type notificationDB struct {
dbContainerImpl
}
// Insert puts new notification to database.
func (db *notificationDB) Insert(ctx context.Context, notification notifications.NewNotification) (_ notifications.Notification, err error) {
defer mon.Task()(&ctx, notification)(&err)
id, err := uuid.New()
if err != nil {
return notifications.Notification{}, ErrNotificationsDB.Wrap(err)
}
createdAt := time.Now().UTC()
query := `
INSERT INTO
notifications (id, sender_id, type, title, message, created_at)
VALUES
(?, ?, ?, ?, ?, ?);
`
_, err = db.ExecContext(ctx, query, id[:], notification.SenderID[:], notification.Type, notification.Title, notification.Message, createdAt)
if err != nil {
return notifications.Notification{}, ErrNotificationsDB.Wrap(err)
}
return notifications.Notification{
ID: *id,
SenderID: notification.SenderID,
Type: notification.Type,
Title: notification.Title,
Message: notification.Message,
ReadAt: nil,
CreatedAt: createdAt,
}, nil
}
// List returns listed page of notifications from database.
func (db *notificationDB) List(ctx context.Context, cursor notifications.NotificationCursor) (_ notifications.NotificationPage, err error) {
defer mon.Task()(&ctx, cursor)(&err)
if cursor.Limit > 50 {
cursor.Limit = 50
}
if cursor.Page == 0 {
return notifications.NotificationPage{}, ErrNotificationsDB.Wrap(errs.New("page can not be 0"))
}
page := notifications.NotificationPage{
Limit: cursor.Limit,
Offset: uint64((cursor.Page - 1) * cursor.Limit),
}
countQuery := `
SELECT
COUNT(id)
FROM
notifications
`
err = db.QueryRowContext(ctx, countQuery).Scan(&page.TotalCount)
if err != nil {
return notifications.NotificationPage{}, ErrNotificationsDB.Wrap(err)
}
if page.TotalCount == 0 {
return page, nil
}
if page.Offset > page.TotalCount-1 {
return notifications.NotificationPage{}, ErrNotificationsDB.Wrap(errs.New("page is out of range"))
}
query := `
SELECT * FROM
notifications
ORDER BY
created_at
LIMIT ? OFFSET ?
`
rows, err := db.QueryContext(ctx, query, page.Limit, page.Offset)
if err != nil {
return notifications.NotificationPage{}, ErrNotificationsDB.Wrap(err)
}
defer func() {
err = errs.Combine(err, ErrNotificationsDB.Wrap(rows.Close()))
}()
for rows.Next() {
notification := notifications.Notification{}
var notificationIDBytes []uint8
var notificationID uuid.UUID
err = rows.Scan(
&notificationIDBytes,
&notification.SenderID,
&notification.Type,
&notification.Title,
&notification.Message,
&notification.ReadAt,
&notification.CreatedAt,
)
if err = rows.Err(); err != nil {
return notifications.NotificationPage{}, ErrNotificationsDB.Wrap(err)
}
notificationID, err = dbutil.BytesToUUID(notificationIDBytes)
if err != nil {
return notifications.NotificationPage{}, ErrNotificationsDB.Wrap(err)
}
notification.ID = notificationID
page.Notifications = append(page.Notifications, notification)
}
page.PageCount = uint(page.TotalCount / uint64(cursor.Limit))
if page.TotalCount%uint64(cursor.Limit) != 0 {
page.PageCount++
}
page.CurrentPage = cursor.Page
return page, nil
}
// Read updates specific notification in database as read.
func (db *notificationDB) Read(ctx context.Context, notificationID uuid.UUID) (err error) {
defer mon.Task()(&ctx, notificationID)(&err)
query := `
UPDATE
notifications
SET
read_at = ?
WHERE
id = ?;
`
result, err := db.ExecContext(ctx, query, time.Now().UTC(), notificationID[:])
if err != nil {
return ErrNotificationsDB.Wrap(err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return ErrNotificationsDB.Wrap(err)
}
if rowsAffected != 1 {
return ErrNotificationsDB.Wrap(ErrNoRows)
}
return nil
}
// ReadAll updates all notifications in database as read.
func (db *notificationDB) ReadAll(ctx context.Context) (err error) {
defer mon.Task()(&ctx)(&err)
query := `
UPDATE
notifications
SET
read_at = ?
WHERE
read_at IS NULL;
`
_, err = db.ExecContext(ctx, query, time.Now().UTC())
return ErrNotificationsDB.Wrap(err)
}

View File

@ -41,6 +41,7 @@ var States = MultiDBStates{
&v25,
&v26,
&v27,
&v28,
},
}

View File

@ -0,0 +1,39 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package testdata
import (
"storj.io/storj/storagenode/storagenodedb"
)
var v28 = MultiDBState{
Version: 28,
DBStates: DBStates{
storagenodedb.UsedSerialsDBName: v27.DBStates[storagenodedb.UsedSerialsDBName],
storagenodedb.StorageUsageDBName: v27.DBStates[storagenodedb.StorageUsageDBName],
storagenodedb.ReputationDBName: v27.DBStates[storagenodedb.ReputationDBName],
storagenodedb.PieceSpaceUsedDBName: v27.DBStates[storagenodedb.PieceSpaceUsedDBName],
storagenodedb.PieceInfoDBName: v27.DBStates[storagenodedb.PieceInfoDBName],
storagenodedb.PieceExpirationDBName: v27.DBStates[storagenodedb.PieceExpirationDBName],
storagenodedb.OrdersDBName: v27.DBStates[storagenodedb.OrdersDBName],
storagenodedb.BandwidthDBName: v27.DBStates[storagenodedb.BandwidthDBName],
storagenodedb.SatellitesDBName: v27.DBStates[storagenodedb.SatellitesDBName],
storagenodedb.DeprecatedInfoDBName: v27.DBStates[storagenodedb.DeprecatedInfoDBName],
storagenodedb.NotificationsDBName: &DBState{
SQL: `
-- table to hold notifications data
CREATE TABLE notifications (
id BLOB NOT NULL,
sender_id BLOB NOT NULL,
type INTEGER NOT NULL,
title TEXT NOT NULL,
message TEXT NOT NULL,
read_at TIMESTAMP,
created_at TIMESTAMP NOT NULL,
PRIMARY KEY (id)
);
`,
},
},
}