From 7183dca6cbe7fcc56c8985dcce652fb0a4f07841 Mon Sep 17 00:00:00 2001 From: Egon Elbre Date: Mon, 2 Nov 2020 14:21:55 +0200 Subject: [PATCH] all: fix defers in loop defer should not be called in a loop. Change-Id: Ifa5a25a56402814b974bcdfb0c2fce56df8e7e59 --- private/dbutil/sqliteutil/query.go | 159 +++++----- private/testplanet/planet_test.go | 24 +- .../consoleweb/consoleapi/auth_test.go | 89 +++--- satellite/metainfo/metainfo_test.go | 112 +++---- satellite/orders/endpoint_test.go | 294 +++++++++--------- satellite/satellitedb/projectaccounting.go | 118 +++---- storage/testsuite/test.go | 18 +- storagenode/orders/service.go | 2 +- storagenode/piecestore/endpoint_test.go | 28 +- storagenode/piecestore/verification_test.go | 78 ++--- storagenode/storagenodedb/database.go | 192 ++++++------ 11 files changed, 578 insertions(+), 536 deletions(-) diff --git a/private/dbutil/sqliteutil/query.go b/private/dbutil/sqliteutil/query.go index 31dbdf2a7..ac4d527d8 100644 --- a/private/dbutil/sqliteutil/query.go +++ b/private/dbutil/sqliteutil/query.go @@ -73,84 +73,97 @@ func QuerySchema(ctx context.Context, db dbschema.Queryer) (*dbschema.Schema, er func discoverTables(ctx context.Context, db dbschema.Queryer, schema *dbschema.Schema, tableDefinitions []*definition) (err error) { for _, definition := range tableDefinitions { - table := schema.EnsureTable(definition.name) - - tableRows, err := db.QueryContext(ctx, `PRAGMA table_info(`+definition.name+`)`) - if err != nil { - return errs.Wrap(err) - } - defer func() { err = errs.Combine(err, tableRows.Close()) }() - - for tableRows.Next() { - var defaultValue sql.NullString - var index, name, columnType string - var pk int - var notNull bool - err := tableRows.Scan(&index, &name, &columnType, ¬Null, &defaultValue, &pk) - if err != nil { - return errs.Wrap(err) - } - - column := &dbschema.Column{ - Name: name, - Type: columnType, - IsNullable: !notNull && pk == 0, - } - table.AddColumn(column) - if pk > 0 { - if table.PrimaryKey == nil { - table.PrimaryKey = make([]string, 0) - } - table.PrimaryKey = append(table.PrimaryKey, name) - } - - } - - matches := rxUnique.FindAllStringSubmatch(definition.sql, -1) - for _, match := range matches { - // TODO feel this can be done easier - var columns []string - for _, name := range strings.Split(match[1], ",") { - columns = append(columns, strings.TrimSpace(name)) - } - - table.Unique = append(table.Unique, columns) - } - - keysRows, err := db.QueryContext(ctx, `PRAGMA foreign_key_list(`+definition.name+`)`) - if err != nil { - return errs.Wrap(err) - } - defer func() { err = errs.Combine(err, keysRows.Close()) }() - - for keysRows.Next() { - var id, sec int - var tableName, from, to, onUpdate, onDelete, match string - err := keysRows.Scan(&id, &sec, &tableName, &from, &to, &onUpdate, &onDelete, &match) - if err != nil { - return errs.Wrap(err) - } - - column, found := table.FindColumn(from) - if found { - if onDelete == "NO ACTION" { - onDelete = "" - } - if onUpdate == "NO ACTION" { - onUpdate = "" - } - column.Reference = &dbschema.Reference{ - Table: tableName, - Column: to, - OnUpdate: onUpdate, - OnDelete: onDelete, - } - } + if err := discoverTable(ctx, db, schema, definition); err != nil { + return err } } return errs.Wrap(err) } +func discoverTable(ctx context.Context, db dbschema.Queryer, schema *dbschema.Schema, definition *definition) (err error) { + table := schema.EnsureTable(definition.name) + + tableRows, err := db.QueryContext(ctx, `PRAGMA table_info(`+definition.name+`)`) + if err != nil { + return errs.Wrap(err) + } + + for tableRows.Next() { + var defaultValue sql.NullString + var index, name, columnType string + var pk int + var notNull bool + err := tableRows.Scan(&index, &name, &columnType, ¬Null, &defaultValue, &pk) + if err != nil { + return errs.Wrap(errs.Combine(tableRows.Err(), tableRows.Close(), err)) + } + + column := &dbschema.Column{ + Name: name, + Type: columnType, + IsNullable: !notNull && pk == 0, + } + table.AddColumn(column) + if pk > 0 { + if table.PrimaryKey == nil { + table.PrimaryKey = make([]string, 0) + } + table.PrimaryKey = append(table.PrimaryKey, name) + } + } + err = errs.Combine(tableRows.Err(), tableRows.Close()) + if err != nil { + return errs.Wrap(err) + } + + matches := rxUnique.FindAllStringSubmatch(definition.sql, -1) + for _, match := range matches { + // TODO feel this can be done easier + var columns []string + for _, name := range strings.Split(match[1], ",") { + columns = append(columns, strings.TrimSpace(name)) + } + + table.Unique = append(table.Unique, columns) + } + + keysRows, err := db.QueryContext(ctx, `PRAGMA foreign_key_list(`+definition.name+`)`) + if err != nil { + return errs.Wrap(err) + } + + for keysRows.Next() { + var id, sec int + var tableName, from, to, onUpdate, onDelete, match string + err := keysRows.Scan(&id, &sec, &tableName, &from, &to, &onUpdate, &onDelete, &match) + if err != nil { + return errs.Wrap(errs.Combine(keysRows.Err(), keysRows.Close(), err)) + } + + column, found := table.FindColumn(from) + if found { + if onDelete == "NO ACTION" { + onDelete = "" + } + if onUpdate == "NO ACTION" { + onUpdate = "" + } + column.Reference = &dbschema.Reference{ + Table: tableName, + Column: to, + OnUpdate: onUpdate, + OnDelete: onDelete, + } + } + } + err = errs.Combine(keysRows.Err(), keysRows.Close()) + if err != nil { + return errs.Wrap(err) + } + + return nil +} + func discoverIndexes(ctx context.Context, db dbschema.Queryer, schema *dbschema.Schema, indexDefinitions []*definition) (err error) { // TODO improve indexes discovery for _, definition := range indexDefinitions { diff --git a/private/testplanet/planet_test.go b/private/testplanet/planet_test.go index 006ee1a9d..a0b2d01b8 100644 --- a/private/testplanet/planet_test.go +++ b/private/testplanet/planet_test.go @@ -34,18 +34,20 @@ func TestBasic(t *testing.T) { for _, sat := range planet.Satellites { for _, sn := range planet.StorageNodes { - node := sn.Contact.Service.Local() - conn, err := sn.Dialer.DialNodeURL(ctx, sat.NodeURL()) + func() { + node := sn.Contact.Service.Local() + conn, err := sn.Dialer.DialNodeURL(ctx, sat.NodeURL()) - require.NoError(t, err) - defer ctx.Check(conn.Close) - _, err = pb.NewDRPCNodeClient(conn).CheckIn(ctx, &pb.CheckInRequest{ - Address: node.Address, - Version: &node.Version, - Capacity: &node.Capacity, - Operator: &node.Operator, - }) - require.NoError(t, err) + require.NoError(t, err) + defer ctx.Check(conn.Close) + _, err = pb.NewDRPCNodeClient(conn).CheckIn(ctx, &pb.CheckInRequest{ + Address: node.Address, + Version: &node.Version, + Capacity: &node.Capacity, + Operator: &node.Operator, + }) + require.NoError(t, err) + }() } } // wait a bit to see whether some failures occur diff --git a/satellite/console/consoleweb/consoleapi/auth_test.go b/satellite/console/consoleweb/consoleapi/auth_test.go index b7845675f..3005a8930 100644 --- a/satellite/console/consoleweb/consoleapi/auth_test.go +++ b/satellite/console/consoleweb/consoleapi/auth_test.go @@ -39,7 +39,6 @@ func TestAuth_Register(t *testing.T) { }, }, }, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) { - for i, test := range []struct { Partner string ValidPartner bool @@ -50,52 +49,54 @@ func TestAuth_Register(t *testing.T) { {Partner: "Raiden nEtwork", ValidPartner: true}, {Partner: "invalid-name", ValidPartner: false}, } { - registerData := struct { - FullName string `json:"fullName"` - ShortName string `json:"shortName"` - Email string `json:"email"` - Partner string `json:"partner"` - PartnerID string `json:"partnerId"` - Password string `json:"password"` - SecretInput string `json:"secret"` - ReferrerUserID string `json:"referrerUserId"` - }{ - FullName: "testuser" + strconv.Itoa(i), - ShortName: "test", - Email: "user@test" + strconv.Itoa(i), - Partner: test.Partner, - Password: "abc123", - } + func() { + registerData := struct { + FullName string `json:"fullName"` + ShortName string `json:"shortName"` + Email string `json:"email"` + Partner string `json:"partner"` + PartnerID string `json:"partnerId"` + Password string `json:"password"` + SecretInput string `json:"secret"` + ReferrerUserID string `json:"referrerUserId"` + }{ + FullName: "testuser" + strconv.Itoa(i), + ShortName: "test", + Email: "user@test" + strconv.Itoa(i), + Partner: test.Partner, + Password: "abc123", + } - jsonBody, err := json.Marshal(registerData) - require.NoError(t, err) - - result, err := http.Post("http://"+planet.Satellites[0].API.Console.Listener.Addr().String()+"/api/v0/auth/register", "application/json", bytes.NewBuffer(jsonBody)) - require.NoError(t, err) - require.Equal(t, http.StatusOK, result.StatusCode) - - defer func() { - err = result.Body.Close() + jsonBody, err := json.Marshal(registerData) require.NoError(t, err) + + result, err := http.Post("http://"+planet.Satellites[0].API.Console.Listener.Addr().String()+"/api/v0/auth/register", "application/json", bytes.NewBuffer(jsonBody)) + require.NoError(t, err) + require.Equal(t, http.StatusOK, result.StatusCode) + + defer func() { + err = result.Body.Close() + require.NoError(t, err) + }() + + body, err := ioutil.ReadAll(result.Body) + require.NoError(t, err) + + var userID uuid.UUID + err = json.Unmarshal(body, &userID) + require.NoError(t, err) + + user, err := planet.Satellites[0].API.Console.Service.GetUser(ctx, userID) + require.NoError(t, err) + + if test.ValidPartner { + info, err := planet.Satellites[0].API.Marketing.PartnersService.ByName(ctx, test.Partner) + require.NoError(t, err) + require.Equal(t, info.UUID, user.PartnerID) + } else { + require.Equal(t, uuid.UUID{}, user.PartnerID) + } }() - - body, err := ioutil.ReadAll(result.Body) - require.NoError(t, err) - - var userID uuid.UUID - err = json.Unmarshal(body, &userID) - require.NoError(t, err) - - user, err := planet.Satellites[0].API.Console.Service.GetUser(ctx, userID) - require.NoError(t, err) - - if test.ValidPartner { - info, err := planet.Satellites[0].API.Marketing.PartnersService.ByName(ctx, test.Partner) - require.NoError(t, err) - require.Equal(t, info.UUID, user.PartnerID) - } else { - require.Equal(t, uuid.UUID{}, user.PartnerID) - } } }) } diff --git a/satellite/metainfo/metainfo_test.go b/satellite/metainfo/metainfo_test.go index 52addae76..994fd64af 100644 --- a/satellite/metainfo/metainfo_test.go +++ b/satellite/metainfo/metainfo_test.go @@ -215,85 +215,87 @@ func TestInvalidAPIKey(t *testing.T) { require.NoError(t, err) for _, invalidAPIKey := range []string{"", "invalid", "testKey"} { - client, err := planet.Uplinks[0].DialMetainfo(ctx, planet.Satellites[0], throwawayKey) - require.NoError(t, err) - defer ctx.Check(client.Close) + func() { + client, err := planet.Uplinks[0].DialMetainfo(ctx, planet.Satellites[0], throwawayKey) + require.NoError(t, err) + defer ctx.Check(client.Close) - client.SetRawAPIKey([]byte(invalidAPIKey)) + client.SetRawAPIKey([]byte(invalidAPIKey)) - _, err = client.BeginObject(ctx, metainfo.BeginObjectParams{}) - assertInvalidArgument(t, err, false) + _, err = client.BeginObject(ctx, metainfo.BeginObjectParams{}) + assertInvalidArgument(t, err, false) - _, err = client.BeginDeleteObject(ctx, metainfo.BeginDeleteObjectParams{}) - assertInvalidArgument(t, err, false) + _, err = client.BeginDeleteObject(ctx, metainfo.BeginDeleteObjectParams{}) + assertInvalidArgument(t, err, false) - _, err = client.ListBuckets(ctx, metainfo.ListBucketsParams{}) - assertInvalidArgument(t, err, false) + _, err = client.ListBuckets(ctx, metainfo.ListBucketsParams{}) + assertInvalidArgument(t, err, false) - _, _, err = client.ListObjects(ctx, metainfo.ListObjectsParams{}) - assertInvalidArgument(t, err, false) + _, _, err = client.ListObjects(ctx, metainfo.ListObjectsParams{}) + assertInvalidArgument(t, err, false) - _, err = client.CreateBucket(ctx, metainfo.CreateBucketParams{}) - assertInvalidArgument(t, err, false) + _, err = client.CreateBucket(ctx, metainfo.CreateBucketParams{}) + assertInvalidArgument(t, err, false) - _, err = client.DeleteBucket(ctx, metainfo.DeleteBucketParams{}) - assertInvalidArgument(t, err, false) + _, err = client.DeleteBucket(ctx, metainfo.DeleteBucketParams{}) + assertInvalidArgument(t, err, false) - _, err = client.BeginDeleteObject(ctx, metainfo.BeginDeleteObjectParams{}) - assertInvalidArgument(t, err, false) + _, err = client.BeginDeleteObject(ctx, metainfo.BeginDeleteObjectParams{}) + assertInvalidArgument(t, err, false) - _, err = client.GetBucket(ctx, metainfo.GetBucketParams{}) - assertInvalidArgument(t, err, false) + _, err = client.GetBucket(ctx, metainfo.GetBucketParams{}) + assertInvalidArgument(t, err, false) - _, err = client.GetObject(ctx, metainfo.GetObjectParams{}) - assertInvalidArgument(t, err, false) + _, err = client.GetObject(ctx, metainfo.GetObjectParams{}) + assertInvalidArgument(t, err, false) - _, err = client.GetProjectInfo(ctx) - assertInvalidArgument(t, err, false) + _, err = client.GetProjectInfo(ctx) + assertInvalidArgument(t, err, false) - // these methods needs StreamID to do authentication + // these methods needs StreamID to do authentication - signer := signing.SignerFromFullIdentity(planet.Satellites[0].Identity) - satStreamID := &internalpb.StreamID{ - CreationDate: time.Now(), - } - signedStreamID, err := satMetainfo.SignStreamID(ctx, signer, satStreamID) - require.NoError(t, err) + signer := signing.SignerFromFullIdentity(planet.Satellites[0].Identity) + satStreamID := &internalpb.StreamID{ + CreationDate: time.Now(), + } + signedStreamID, err := satMetainfo.SignStreamID(ctx, signer, satStreamID) + require.NoError(t, err) - encodedStreamID, err := pb.Marshal(signedStreamID) - require.NoError(t, err) + encodedStreamID, err := pb.Marshal(signedStreamID) + require.NoError(t, err) - streamID, err := storj.StreamIDFromBytes(encodedStreamID) - require.NoError(t, err) + streamID, err := storj.StreamIDFromBytes(encodedStreamID) + require.NoError(t, err) - err = client.CommitObject(ctx, metainfo.CommitObjectParams{StreamID: streamID}) - assertInvalidArgument(t, err, false) + err = client.CommitObject(ctx, metainfo.CommitObjectParams{StreamID: streamID}) + assertInvalidArgument(t, err, false) - _, _, _, err = client.BeginSegment(ctx, metainfo.BeginSegmentParams{StreamID: streamID}) - assertInvalidArgument(t, err, false) + _, _, _, err = client.BeginSegment(ctx, metainfo.BeginSegmentParams{StreamID: streamID}) + assertInvalidArgument(t, err, false) - err = client.MakeInlineSegment(ctx, metainfo.MakeInlineSegmentParams{StreamID: streamID}) - assertInvalidArgument(t, err, false) + err = client.MakeInlineSegment(ctx, metainfo.MakeInlineSegmentParams{StreamID: streamID}) + assertInvalidArgument(t, err, false) - _, _, err = client.DownloadSegment(ctx, metainfo.DownloadSegmentParams{StreamID: streamID}) - assertInvalidArgument(t, err, false) + _, _, err = client.DownloadSegment(ctx, metainfo.DownloadSegmentParams{StreamID: streamID}) + assertInvalidArgument(t, err, false) - // these methods needs SegmentID + // these methods needs SegmentID - signedSegmentID, err := satMetainfo.SignSegmentID(ctx, signer, &internalpb.SegmentID{ - StreamId: satStreamID, - CreationDate: time.Now(), - }) - require.NoError(t, err) + signedSegmentID, err := satMetainfo.SignSegmentID(ctx, signer, &internalpb.SegmentID{ + StreamId: satStreamID, + CreationDate: time.Now(), + }) + require.NoError(t, err) - encodedSegmentID, err := pb.Marshal(signedSegmentID) - require.NoError(t, err) + encodedSegmentID, err := pb.Marshal(signedSegmentID) + require.NoError(t, err) - segmentID, err := storj.SegmentIDFromBytes(encodedSegmentID) - require.NoError(t, err) + segmentID, err := storj.SegmentIDFromBytes(encodedSegmentID) + require.NoError(t, err) - err = client.CommitSegment(ctx, metainfo.CommitSegmentParams{SegmentID: segmentID}) - assertInvalidArgument(t, err, false) + err = client.CommitSegment(ctx, metainfo.CommitSegmentParams{SegmentID: segmentID}) + assertInvalidArgument(t, err, false) + }() } }) } diff --git a/satellite/orders/endpoint_test.go b/satellite/orders/endpoint_test.go index 256ec0911..9ca8ad622 100644 --- a/satellite/orders/endpoint_test.go +++ b/satellite/orders/endpoint_test.go @@ -76,106 +76,108 @@ func TestSettlementWithWindowEndpointManyOrders(t *testing.T) { } for _, tt := range testCases { - // create serial number to use in test. must be unique for each run. - serialNumber1 := testrand.SerialNumber() - err = ordersDB.CreateSerialInfo(ctx, serialNumber1, []byte(bucketID), now.AddDate(1, 0, 10)) - require.NoError(t, err) + func() { + // create serial number to use in test. must be unique for each run. + serialNumber1 := testrand.SerialNumber() + err = ordersDB.CreateSerialInfo(ctx, serialNumber1, []byte(bucketID), now.AddDate(1, 0, 10)) + require.NoError(t, err) - serialNumber2 := testrand.SerialNumber() - err = ordersDB.CreateSerialInfo(ctx, serialNumber2, []byte(bucketID), now.AddDate(1, 0, 10)) - require.NoError(t, err) + serialNumber2 := testrand.SerialNumber() + err = ordersDB.CreateSerialInfo(ctx, serialNumber2, []byte(bucketID), now.AddDate(1, 0, 10)) + require.NoError(t, err) - piecePublicKey, piecePrivateKey, err := storj.NewPieceKey() - require.NoError(t, err) + piecePublicKey, piecePrivateKey, err := storj.NewPieceKey() + require.NoError(t, err) - // create signed orderlimit or order to test with - limit1 := &pb.OrderLimit{ - SerialNumber: serialNumber1, - SatelliteId: satellite.ID(), - UplinkPublicKey: piecePublicKey, - StorageNodeId: storagenode.ID(), - PieceId: storj.NewPieceID(), - Action: pb.PieceAction_PUT, - Limit: 1000, - PieceExpiration: time.Time{}, - OrderCreation: tt.orderCreation, - OrderExpiration: now.Add(24 * time.Hour), - } - orderLimit1, err := signing.SignOrderLimit(ctx, signing.SignerFromFullIdentity(satellite.Identity), limit1) - require.NoError(t, err) + // create signed orderlimit or order to test with + limit1 := &pb.OrderLimit{ + SerialNumber: serialNumber1, + SatelliteId: satellite.ID(), + UplinkPublicKey: piecePublicKey, + StorageNodeId: storagenode.ID(), + PieceId: storj.NewPieceID(), + Action: pb.PieceAction_PUT, + Limit: 1000, + PieceExpiration: time.Time{}, + OrderCreation: tt.orderCreation, + OrderExpiration: now.Add(24 * time.Hour), + } + orderLimit1, err := signing.SignOrderLimit(ctx, signing.SignerFromFullIdentity(satellite.Identity), limit1) + require.NoError(t, err) - order1, err := signing.SignUplinkOrder(ctx, piecePrivateKey, &pb.Order{ - SerialNumber: serialNumber1, - Amount: tt.dataAmount, - }) - require.NoError(t, err) + order1, err := signing.SignUplinkOrder(ctx, piecePrivateKey, &pb.Order{ + SerialNumber: serialNumber1, + Amount: tt.dataAmount, + }) + require.NoError(t, err) - limit2 := &pb.OrderLimit{ - SerialNumber: serialNumber2, - SatelliteId: satellite.ID(), - UplinkPublicKey: piecePublicKey, - StorageNodeId: storagenode.ID(), - PieceId: storj.NewPieceID(), - Action: pb.PieceAction_PUT, - Limit: 1000, - PieceExpiration: time.Time{}, - OrderCreation: now, - OrderExpiration: now.Add(24 * time.Hour), - } - orderLimit2, err := signing.SignOrderLimit(ctx, signing.SignerFromFullIdentity(satellite.Identity), limit2) - require.NoError(t, err) + limit2 := &pb.OrderLimit{ + SerialNumber: serialNumber2, + SatelliteId: satellite.ID(), + UplinkPublicKey: piecePublicKey, + StorageNodeId: storagenode.ID(), + PieceId: storj.NewPieceID(), + Action: pb.PieceAction_PUT, + Limit: 1000, + PieceExpiration: time.Time{}, + OrderCreation: now, + OrderExpiration: now.Add(24 * time.Hour), + } + orderLimit2, err := signing.SignOrderLimit(ctx, signing.SignerFromFullIdentity(satellite.Identity), limit2) + require.NoError(t, err) - order2, err := signing.SignUplinkOrder(ctx, piecePrivateKey, &pb.Order{ - SerialNumber: serialNumber2, - Amount: tt.dataAmount, - }) - require.NoError(t, err) + order2, err := signing.SignUplinkOrder(ctx, piecePrivateKey, &pb.Order{ + SerialNumber: serialNumber2, + Amount: tt.dataAmount, + }) + require.NoError(t, err) - // create connection between storagenode and satellite - conn, err := storagenode.Dialer.DialNodeURL(ctx, storj.NodeURL{ID: satellite.ID(), Address: satellite.Addr()}) - require.NoError(t, err) - defer ctx.Check(conn.Close) + // create connection between storagenode and satellite + conn, err := storagenode.Dialer.DialNodeURL(ctx, storj.NodeURL{ID: satellite.ID(), Address: satellite.Addr()}) + require.NoError(t, err) + defer ctx.Check(conn.Close) - stream, err := pb.NewDRPCOrdersClient(conn).SettlementWithWindow(ctx) - require.NoError(t, err) - defer ctx.Check(stream.Close) + stream, err := pb.NewDRPCOrdersClient(conn).SettlementWithWindow(ctx) + require.NoError(t, err) + defer ctx.Check(stream.Close) - // storagenode settles an order and orderlimit - err = stream.Send(&pb.SettlementRequest{ - Limit: orderLimit1, - Order: order1, - }) - require.NoError(t, err) - err = stream.Send(&pb.SettlementRequest{ - Limit: orderLimit2, - Order: order2, - }) - require.NoError(t, err) - resp, err := stream.CloseAndRecv() - require.NoError(t, err) + // storagenode settles an order and orderlimit + err = stream.Send(&pb.SettlementRequest{ + Limit: orderLimit1, + Order: order1, + }) + require.NoError(t, err) + err = stream.Send(&pb.SettlementRequest{ + Limit: orderLimit2, + Order: order2, + }) + require.NoError(t, err) + resp, err := stream.CloseAndRecv() + require.NoError(t, err) - // the settled amount is only returned during phase3 - var settled map[int32]int64 - if satellite.Config.Orders.WindowEndpointRolloutPhase == orders.WindowEndpointRolloutPhase3 { - settled = map[int32]int64{int32(pb.PieceAction_PUT): tt.settledAmt} - } - require.Equal(t, &pb.SettlementWithWindowResponse{ - Status: pb.SettlementWithWindowResponse_ACCEPTED, - ActionSettled: settled, - }, resp) + // the settled amount is only returned during phase3 + var settled map[int32]int64 + if satellite.Config.Orders.WindowEndpointRolloutPhase == orders.WindowEndpointRolloutPhase3 { + settled = map[int32]int64{int32(pb.PieceAction_PUT): tt.settledAmt} + } + require.Equal(t, &pb.SettlementWithWindowResponse{ + Status: pb.SettlementWithWindowResponse_ACCEPTED, + ActionSettled: settled, + }, resp) - // trigger and wait for all of the chores necessary to flush the orders - assert.NoError(t, satellite.Accounting.ReportedRollup.RunOnce(ctx, tt.orderCreation)) - satellite.Orders.Chore.Loop.TriggerWait() + // trigger and wait for all of the chores necessary to flush the orders + assert.NoError(t, satellite.Accounting.ReportedRollup.RunOnce(ctx, tt.orderCreation)) + satellite.Orders.Chore.Loop.TriggerWait() - // assert all the right stuff is in the satellite storagenode and bucket bandwidth tables - snbw, err = ordersDB.GetStorageNodeBandwidth(ctx, storagenode.ID(), time.Time{}, tt.orderCreation) - require.NoError(t, err) - require.EqualValues(t, tt.settledAmt, snbw) + // assert all the right stuff is in the satellite storagenode and bucket bandwidth tables + snbw, err = ordersDB.GetStorageNodeBandwidth(ctx, storagenode.ID(), time.Time{}, tt.orderCreation) + require.NoError(t, err) + require.EqualValues(t, tt.settledAmt, snbw) - newBbw, err := ordersDB.GetBucketBandwidth(ctx, projectID, []byte(bucketname), time.Time{}, tt.orderCreation) - require.NoError(t, err) - require.EqualValues(t, tt.settledAmt, newBbw) + newBbw, err := ordersDB.GetBucketBandwidth(ctx, projectID, []byte(bucketname), time.Time{}, tt.orderCreation) + require.NoError(t, err) + require.EqualValues(t, tt.settledAmt, newBbw) + }() } }) } @@ -223,72 +225,74 @@ func TestSettlementWithWindowEndpointSingleOrder(t *testing.T) { } for _, tt := range testCases { - // create signed orderlimit or order to test with - limit := &pb.OrderLimit{ - SerialNumber: serialNumber, - SatelliteId: satellite.ID(), - UplinkPublicKey: piecePublicKey, - StorageNodeId: storagenode.ID(), - PieceId: storj.NewPieceID(), - Action: pb.PieceAction_PUT, - Limit: 1000, - PieceExpiration: time.Time{}, - OrderCreation: now, - OrderExpiration: now.Add(24 * time.Hour), - } - orderLimit, err := signing.SignOrderLimit(ctx, signing.SignerFromFullIdentity(satellite.Identity), limit) - require.NoError(t, err) + func() { + // create signed orderlimit or order to test with + limit := &pb.OrderLimit{ + SerialNumber: serialNumber, + SatelliteId: satellite.ID(), + UplinkPublicKey: piecePublicKey, + StorageNodeId: storagenode.ID(), + PieceId: storj.NewPieceID(), + Action: pb.PieceAction_PUT, + Limit: 1000, + PieceExpiration: time.Time{}, + OrderCreation: now, + OrderExpiration: now.Add(24 * time.Hour), + } + orderLimit, err := signing.SignOrderLimit(ctx, signing.SignerFromFullIdentity(satellite.Identity), limit) + require.NoError(t, err) - order, err := signing.SignUplinkOrder(ctx, piecePrivateKey, &pb.Order{ - SerialNumber: serialNumber, - Amount: tt.dataAmount, - }) - require.NoError(t, err) + order, err := signing.SignUplinkOrder(ctx, piecePrivateKey, &pb.Order{ + SerialNumber: serialNumber, + Amount: tt.dataAmount, + }) + require.NoError(t, err) - // create connection between storagenode and satellite - conn, err := storagenode.Dialer.DialNodeURL(ctx, storj.NodeURL{ID: satellite.ID(), Address: satellite.Addr()}) - require.NoError(t, err) - defer ctx.Check(conn.Close) + // create connection between storagenode and satellite + conn, err := storagenode.Dialer.DialNodeURL(ctx, storj.NodeURL{ID: satellite.ID(), Address: satellite.Addr()}) + require.NoError(t, err) + defer ctx.Check(conn.Close) - stream, err := pb.NewDRPCOrdersClient(conn).SettlementWithWindow(ctx) - require.NoError(t, err) - defer ctx.Check(stream.Close) + stream, err := pb.NewDRPCOrdersClient(conn).SettlementWithWindow(ctx) + require.NoError(t, err) + defer ctx.Check(stream.Close) - // storagenode settles an order and orderlimit - err = stream.Send(&pb.SettlementRequest{ - Limit: orderLimit, - Order: order, - }) - require.NoError(t, err) - resp, err := stream.CloseAndRecv() - require.NoError(t, err) + // storagenode settles an order and orderlimit + err = stream.Send(&pb.SettlementRequest{ + Limit: orderLimit, + Order: order, + }) + require.NoError(t, err) + resp, err := stream.CloseAndRecv() + require.NoError(t, err) - expected := new(pb.SettlementWithWindowResponse) - switch { - case satellite.Config.Orders.WindowEndpointRolloutPhase != orders.WindowEndpointRolloutPhase3: - expected.Status = pb.SettlementWithWindowResponse_ACCEPTED - expected.ActionSettled = nil - case tt.expectedStatus == pb.SettlementWithWindowResponse_ACCEPTED: - expected.Status = pb.SettlementWithWindowResponse_ACCEPTED - expected.ActionSettled = map[int32]int64{int32(pb.PieceAction_PUT): tt.dataAmount} - default: - expected.Status = pb.SettlementWithWindowResponse_REJECTED - expected.ActionSettled = nil - } - require.Equal(t, expected, resp) + expected := new(pb.SettlementWithWindowResponse) + switch { + case satellite.Config.Orders.WindowEndpointRolloutPhase != orders.WindowEndpointRolloutPhase3: + expected.Status = pb.SettlementWithWindowResponse_ACCEPTED + expected.ActionSettled = nil + case tt.expectedStatus == pb.SettlementWithWindowResponse_ACCEPTED: + expected.Status = pb.SettlementWithWindowResponse_ACCEPTED + expected.ActionSettled = map[int32]int64{int32(pb.PieceAction_PUT): tt.dataAmount} + default: + expected.Status = pb.SettlementWithWindowResponse_REJECTED + expected.ActionSettled = nil + } + require.Equal(t, expected, resp) - // flush all the chores - assert.NoError(t, satellite.Accounting.ReportedRollup.RunOnce(ctx, now)) - satellite.Orders.Chore.Loop.TriggerWait() + // flush all the chores + assert.NoError(t, satellite.Accounting.ReportedRollup.RunOnce(ctx, now)) + satellite.Orders.Chore.Loop.TriggerWait() - // assert all the right stuff is in the satellite storagenode and bucket bandwidth tables - snbw, err = ordersDB.GetStorageNodeBandwidth(ctx, storagenode.ID(), time.Time{}, now) - require.NoError(t, err) - require.Equal(t, dataAmount, snbw) + // assert all the right stuff is in the satellite storagenode and bucket bandwidth tables + snbw, err = ordersDB.GetStorageNodeBandwidth(ctx, storagenode.ID(), time.Time{}, now) + require.NoError(t, err) + require.Equal(t, dataAmount, snbw) - newBbw, err := ordersDB.GetBucketBandwidth(ctx, projectID, []byte(bucketname), time.Time{}, now) - require.NoError(t, err) - require.Equal(t, dataAmount, newBbw) + newBbw, err := ordersDB.GetBucketBandwidth(ctx, projectID, []byte(bucketname), time.Time{}, now) + require.NoError(t, err) + require.Equal(t, dataAmount, newBbw) + }() } }) } diff --git a/satellite/satellitedb/projectaccounting.go b/satellite/satellitedb/projectaccounting.go index c918513e6..bb6d5781c 100644 --- a/satellite/satellitedb/projectaccounting.go +++ b/satellite/satellitedb/projectaccounting.go @@ -358,72 +358,78 @@ func (db *ProjectAccounting) GetBucketUsageRollups(ctx context.Context, projectI var bucketUsageRollups []accounting.BucketUsageRollup for _, bucket := range buckets { - bucketRollup := accounting.BucketUsageRollup{ - ProjectID: projectID, - BucketName: []byte(bucket), - Since: since, - Before: before, - } + err := func() error { + bucketRollup := accounting.BucketUsageRollup{ + ProjectID: projectID, + BucketName: []byte(bucket), + Since: since, + Before: before, + } - // get bucket_bandwidth_rollups - rollupsRows, err := db.db.QueryContext(ctx, roullupsQuery, projectID[:], []byte(bucket), since, before) - if err != nil { - return nil, err - } - defer func() { err = errs.Combine(err, rollupsRows.Close()) }() - - // fill egress - for rollupsRows.Next() { - var action pb.PieceAction - var settled, inline int64 - - err = rollupsRows.Scan(&settled, &inline, &action) + // get bucket_bandwidth_rollups + rollupsRows, err := db.db.QueryContext(ctx, roullupsQuery, projectID[:], []byte(bucket), since, before) if err != nil { - return nil, err + return err + } + defer func() { err = errs.Combine(err, rollupsRows.Close()) }() + + // fill egress + for rollupsRows.Next() { + var action pb.PieceAction + var settled, inline int64 + + err = rollupsRows.Scan(&settled, &inline, &action) + if err != nil { + return err + } + + switch action { + case pb.PieceAction_GET: + bucketRollup.GetEgress += memory.Size(settled + inline).GB() + case pb.PieceAction_GET_AUDIT: + bucketRollup.AuditEgress += memory.Size(settled + inline).GB() + case pb.PieceAction_GET_REPAIR: + bucketRollup.RepairEgress += memory.Size(settled + inline).GB() + default: + continue + } + } + if err := rollupsRows.Err(); err != nil { + return err } - switch action { - case pb.PieceAction_GET: - bucketRollup.GetEgress += memory.Size(settled + inline).GB() - case pb.PieceAction_GET_AUDIT: - bucketRollup.AuditEgress += memory.Size(settled + inline).GB() - case pb.PieceAction_GET_REPAIR: - bucketRollup.RepairEgress += memory.Size(settled + inline).GB() - default: - continue + bucketStorageTallies, err := storageQuery(ctx, + dbx.BucketStorageTally_ProjectId(projectID[:]), + dbx.BucketStorageTally_BucketName([]byte(bucket)), + dbx.BucketStorageTally_IntervalStart(since), + dbx.BucketStorageTally_IntervalStart(before)) + + if err != nil { + return err } - } - if err := rollupsRows.Err(); err != nil { - return nil, err - } - bucketStorageTallies, err := storageQuery(ctx, - dbx.BucketStorageTally_ProjectId(projectID[:]), - dbx.BucketStorageTally_BucketName([]byte(bucket)), - dbx.BucketStorageTally_IntervalStart(since), - dbx.BucketStorageTally_IntervalStart(before)) + // fill metadata, objects and stored data + // hours calculated from previous tallies, + // so we skip the most recent one + for i := len(bucketStorageTallies) - 1; i > 0; i-- { + current := bucketStorageTallies[i] + hours := bucketStorageTallies[i-1].IntervalStart.Sub(current.IntervalStart).Hours() + + bucketRollup.RemoteStoredData += memory.Size(current.Remote).GB() * hours + bucketRollup.InlineStoredData += memory.Size(current.Inline).GB() * hours + bucketRollup.MetadataSize += memory.Size(current.MetadataSize).GB() * hours + bucketRollup.RemoteSegments += float64(current.RemoteSegmentsCount) * hours + bucketRollup.InlineSegments += float64(current.InlineSegmentsCount) * hours + bucketRollup.ObjectCount += float64(current.ObjectCount) * hours + } + + bucketUsageRollups = append(bucketUsageRollups, bucketRollup) + return nil + }() if err != nil { return nil, err } - - // fill metadata, objects and stored data - // hours calculated from previous tallies, - // so we skip the most recent one - for i := len(bucketStorageTallies) - 1; i > 0; i-- { - current := bucketStorageTallies[i] - - hours := bucketStorageTallies[i-1].IntervalStart.Sub(current.IntervalStart).Hours() - - bucketRollup.RemoteStoredData += memory.Size(current.Remote).GB() * hours - bucketRollup.InlineStoredData += memory.Size(current.Inline).GB() * hours - bucketRollup.MetadataSize += memory.Size(current.MetadataSize).GB() * hours - bucketRollup.RemoteSegments += float64(current.RemoteSegmentsCount) * hours - bucketRollup.InlineSegments += float64(current.InlineSegmentsCount) * hours - bucketRollup.ObjectCount += float64(current.ObjectCount) * hours - } - - bucketUsageRollups = append(bucketUsageRollups, bucketRollup) } return bucketUsageRollups, nil diff --git a/storage/testsuite/test.go b/storage/testsuite/test.go index ab475910a..f817728c9 100644 --- a/storage/testsuite/test.go +++ b/storage/testsuite/test.go @@ -163,16 +163,18 @@ func testConstraints(t *testing.T, ctx *testcontext.Context, store storage.KeyVa {storage.Value("old-value"), nil}, {storage.Value("old-value"), storage.Value("new-value")}, } { - errTag := fmt.Sprintf("%d. %+v", i, tt) - key := storage.Key("test-key") - val := storage.Value("test-value") - defer func() { _ = store.Delete(ctx, key) }() + func() { + errTag := fmt.Sprintf("%d. %+v", i, tt) + key := storage.Key("test-key") + val := storage.Value("test-value") + defer func() { _ = store.Delete(ctx, key) }() - err := store.Put(ctx, key, val) - require.NoError(t, err, errTag) + err := store.Put(ctx, key, val) + require.NoError(t, err, errTag) - err = store.CompareAndSwap(ctx, key, tt.old, tt.new) - assert.True(t, storage.ErrValueChanged.Has(err), "%s: unexpected error: %+v", errTag, err) + err = store.CompareAndSwap(ctx, key, tt.old, tt.new) + assert.True(t, storage.ErrValueChanged.Has(err), "%s: unexpected error: %+v", errTag, err) + }() } }) diff --git a/storagenode/orders/service.go b/storagenode/orders/service.go index ba353d13b..4ec623b0a 100644 --- a/storagenode/orders/service.go +++ b/storagenode/orders/service.go @@ -398,7 +398,6 @@ func (service *Service) sendOrdersFromFileStore(ctx context.Context, now time.Ti var group errgroup.Group attemptedSatellites := 0 ctx, cancel := context.WithTimeout(ctx, service.config.SenderTimeout) - defer cancel() for satelliteID, unsentInfo := range ordersBySatellite { satelliteID, unsentInfo := satelliteID, unsentInfo @@ -430,6 +429,7 @@ func (service *Service) sendOrdersFromFileStore(ctx context.Context, now time.Ti } _ = group.Wait() // doesn't return errors + cancel() // if all satellites that orders need to be sent to are offline, exit and try again later. if attemptedSatellites == 0 { diff --git a/storagenode/piecestore/endpoint_test.go b/storagenode/piecestore/endpoint_test.go index f21518324..7d010ed73 100644 --- a/storagenode/piecestore/endpoint_test.go +++ b/storagenode/piecestore/endpoint_test.go @@ -50,23 +50,25 @@ func TestUploadAndPartialDownload(t *testing.T) { {1513, 1584}, {13581, 4783}, } { - if piecestore.DefaultConfig.InitialStep < tt.size { - t.Fatal("test expects initial step to be larger than size to download") - } - totalDownload += piecestore.DefaultConfig.InitialStep + func() { + if piecestore.DefaultConfig.InitialStep < tt.size { + t.Fatal("test expects initial step to be larger than size to download") + } + totalDownload += piecestore.DefaultConfig.InitialStep - download, cleanup, err := planet.Uplinks[0].DownloadStreamRange(ctx, planet.Satellites[0], "testbucket", "test/path", tt.offset, -1) - require.NoError(t, err) - defer ctx.Check(cleanup) + download, cleanup, err := planet.Uplinks[0].DownloadStreamRange(ctx, planet.Satellites[0], "testbucket", "test/path", tt.offset, -1) + require.NoError(t, err) + defer ctx.Check(cleanup) - data := make([]byte, tt.size) - n, err := io.ReadFull(download, data) - require.NoError(t, err) - assert.Equal(t, int(tt.size), n) + data := make([]byte, tt.size) + n, err := io.ReadFull(download, data) + require.NoError(t, err) + assert.Equal(t, int(tt.size), n) - assert.Equal(t, expectedData[tt.offset:tt.offset+tt.size], data) + assert.Equal(t, expectedData[tt.offset:tt.offset+tt.size], data) - require.NoError(t, download.Close()) + require.NoError(t, download.Close()) + }() } var totalBandwidthUsage bandwidth.Usage diff --git a/storagenode/piecestore/verification_test.go b/storagenode/piecestore/verification_test.go index b60a117eb..d27427e5c 100644 --- a/storagenode/piecestore/verification_test.go +++ b/storagenode/piecestore/verification_test.go @@ -226,45 +226,47 @@ func TestOrderLimitGetValidation(t *testing.T) { err: "expected get or get repair or audit action got PUT", }, } { - client, err := planet.Uplinks[0].DialPiecestore(ctx, planet.StorageNodes[0]) - require.NoError(t, err) - defer ctx.Check(client.Close) - - signer := signing.SignerFromFullIdentity(planet.Satellites[0].Identity) - satellite := planet.Satellites[0].Identity - if tt.satellite != nil { - signer = signing.SignerFromFullIdentity(tt.satellite) - satellite = tt.satellite - } - - orderLimit, piecePrivateKey := GenerateOrderLimit( - t, - satellite.ID, - planet.StorageNodes[0].ID(), - tt.pieceID, - tt.action, - tt.serialNumber, - tt.pieceExpiration, - tt.orderExpiration, - tt.limit, - ) - - orderLimit, err = signing.SignOrderLimit(ctx, signer, orderLimit) - require.NoError(t, err) - - downloader, err := client.Download(ctx, orderLimit, piecePrivateKey, 0, tt.limit) - require.NoError(t, err) - - buffer, readErr := ioutil.ReadAll(downloader) - closeErr := downloader.Close() - err = errs.Combine(readErr, closeErr) - if tt.err != "" { - assert.Equal(t, 0, len(buffer)) - require.Error(t, err) - require.Contains(t, err.Error(), tt.err) - } else { + func() { + client, err := planet.Uplinks[0].DialPiecestore(ctx, planet.StorageNodes[0]) require.NoError(t, err) - } + defer ctx.Check(client.Close) + + signer := signing.SignerFromFullIdentity(planet.Satellites[0].Identity) + satellite := planet.Satellites[0].Identity + if tt.satellite != nil { + signer = signing.SignerFromFullIdentity(tt.satellite) + satellite = tt.satellite + } + + orderLimit, piecePrivateKey := GenerateOrderLimit( + t, + satellite.ID, + planet.StorageNodes[0].ID(), + tt.pieceID, + tt.action, + tt.serialNumber, + tt.pieceExpiration, + tt.orderExpiration, + tt.limit, + ) + + orderLimit, err = signing.SignOrderLimit(ctx, signer, orderLimit) + require.NoError(t, err) + + downloader, err := client.Download(ctx, orderLimit, piecePrivateKey, 0, tt.limit) + require.NoError(t, err) + + buffer, readErr := ioutil.ReadAll(downloader) + closeErr := downloader.Close() + err = errs.Combine(readErr, closeErr) + if tt.err != "" { + assert.Equal(t, 0, len(buffer)) + require.Error(t, err) + require.Contains(t, err.Error(), tt.err) + } else { + require.NoError(t, err) + } + }() } }) } diff --git a/storagenode/storagenodedb/database.go b/storagenode/storagenodedb/database.go index 3e3824d0f..4b6b31bc6 100644 --- a/storagenode/storagenodedb/database.go +++ b/storagenode/storagenodedb/database.go @@ -349,103 +349,111 @@ func (db *DB) MigrateToLatest(ctx context.Context) error { // Preflight conducts a pre-flight check to ensure correct schemas and minimal read+write functionality of the database tables. func (db *DB) Preflight(ctx context.Context) (err error) { for dbName, dbContainer := range db.SQLDBs { - nextDB := dbContainer.GetDB() - // Preflight stage 1: test schema correctness - schema, err := sqliteutil.QuerySchema(ctx, nextDB) - if err != nil { - return ErrPreflight.New("database %q: schema check failed: %v", dbName, err) - } - // we don't care about changes in versions table - schema.DropTable("versions") - // if there was a previous pre-flight failure, test_table might still be in the schema - schema.DropTable("test_table") - - // If tables and indexes of the schema are empty, set to nil - // to help with comparison to the snapshot. - if len(schema.Tables) == 0 { - schema.Tables = nil - } - if len(schema.Indexes) == 0 { - schema.Indexes = nil - } - - // get expected schema - expectedSchema := Schema()[dbName] - - // find extra indexes - var extraIdxs []*dbschema.Index - for _, idx := range schema.Indexes { - if _, exists := expectedSchema.FindIndex(idx.Name); exists { - continue - } - - extraIdxs = append(extraIdxs, idx) - } - // drop index from schema if it is not unique to not fail preflight - for _, idx := range extraIdxs { - if !idx.Unique { - schema.DropIndex(idx.Name) - } - } - // warn that schema contains unexpected indexes - if len(extraIdxs) > 0 { - db.log.Warn(fmt.Sprintf("database %q: schema contains unexpected indices %v", dbName, extraIdxs)) - } - - // expect expected schema to match actual schema - if diff := cmp.Diff(expectedSchema, schema); diff != "" { - return ErrPreflight.New("database %q: expected schema does not match actual: %s", dbName, diff) - } - - // Preflight stage 2: test basic read/write access - // for each database, create a new table, insert a row into that table, retrieve and validate that row, and drop the table. - - // drop test table in case the last preflight check failed before table could be dropped - _, err = nextDB.ExecContext(ctx, "DROP TABLE IF EXISTS test_table") - if err != nil { - return ErrPreflight.New("database %q: failed drop if test_table: %w", dbName, err) - } - _, err = nextDB.ExecContext(ctx, "CREATE TABLE test_table(id int NOT NULL, name varchar(30), PRIMARY KEY (id))") - if err != nil { - return ErrPreflight.New("database %q: failed create test_table: %w", dbName, err) - } - - var expectedID, actualID int - var expectedName, actualName string - expectedID = 1 - expectedName = "TEST" - _, err = nextDB.ExecContext(ctx, "INSERT INTO test_table VALUES ( ?, ? )", expectedID, expectedName) - if err != nil { - return ErrPreflight.New("database: %q: failed inserting test value: %w", dbName, err) - } - - rows, err := nextDB.QueryContext(ctx, "SELECT id, name FROM test_table") - if err != nil { - return ErrPreflight.New("database: %q: failed selecting test value: %w", dbName, err) - } - defer func() { err = errs.Combine(err, rows.Close()) }() - if !rows.Next() { - return ErrPreflight.New("database %q: no rows in test_table", dbName) - } - err = rows.Scan(&actualID, &actualName) - if err != nil { - return ErrPreflight.New("database %q: failed scanning row: %w", dbName, err) - } - if expectedID != actualID || expectedName != actualName { - return ErrPreflight.New("database %q: expected (%d, '%s'), actual (%d, '%s')", dbName, expectedID, expectedName, actualID, actualName) - } - if rows.Next() { - return ErrPreflight.New("database %q: more than one row in test_table", dbName) - } - - _, err = nextDB.ExecContext(ctx, "DROP TABLE test_table") - if err != nil { - return ErrPreflight.New("database %q: failed drop test_table %w", dbName, err) + if err := db.preflight(ctx, dbName, dbContainer); err != nil { + return err } } return nil } +func (db *DB) preflight(ctx context.Context, dbName string, dbContainer DBContainer) error { + nextDB := dbContainer.GetDB() + // Preflight stage 1: test schema correctness + schema, err := sqliteutil.QuerySchema(ctx, nextDB) + if err != nil { + return ErrPreflight.New("database %q: schema check failed: %v", dbName, err) + } + // we don't care about changes in versions table + schema.DropTable("versions") + // if there was a previous pre-flight failure, test_table might still be in the schema + schema.DropTable("test_table") + + // If tables and indexes of the schema are empty, set to nil + // to help with comparison to the snapshot. + if len(schema.Tables) == 0 { + schema.Tables = nil + } + if len(schema.Indexes) == 0 { + schema.Indexes = nil + } + + // get expected schema + expectedSchema := Schema()[dbName] + + // find extra indexes + var extraIdxs []*dbschema.Index + for _, idx := range schema.Indexes { + if _, exists := expectedSchema.FindIndex(idx.Name); exists { + continue + } + + extraIdxs = append(extraIdxs, idx) + } + // drop index from schema if it is not unique to not fail preflight + for _, idx := range extraIdxs { + if !idx.Unique { + schema.DropIndex(idx.Name) + } + } + // warn that schema contains unexpected indexes + if len(extraIdxs) > 0 { + db.log.Warn(fmt.Sprintf("database %q: schema contains unexpected indices %v", dbName, extraIdxs)) + } + + // expect expected schema to match actual schema + if diff := cmp.Diff(expectedSchema, schema); diff != "" { + return ErrPreflight.New("database %q: expected schema does not match actual: %s", dbName, diff) + } + + // Preflight stage 2: test basic read/write access + // for each database, create a new table, insert a row into that table, retrieve and validate that row, and drop the table. + + // drop test table in case the last preflight check failed before table could be dropped + _, err = nextDB.ExecContext(ctx, "DROP TABLE IF EXISTS test_table") + if err != nil { + return ErrPreflight.New("database %q: failed drop if test_table: %w", dbName, err) + } + _, err = nextDB.ExecContext(ctx, "CREATE TABLE test_table(id int NOT NULL, name varchar(30), PRIMARY KEY (id))") + if err != nil { + return ErrPreflight.New("database %q: failed create test_table: %w", dbName, err) + } + + var expectedID, actualID int + var expectedName, actualName string + expectedID = 1 + expectedName = "TEST" + _, err = nextDB.ExecContext(ctx, "INSERT INTO test_table VALUES ( ?, ? )", expectedID, expectedName) + if err != nil { + return ErrPreflight.New("database: %q: failed inserting test value: %w", dbName, err) + } + + rows, err := nextDB.QueryContext(ctx, "SELECT id, name FROM test_table") + if err != nil { + return ErrPreflight.New("database: %q: failed selecting test value: %w", dbName, err) + } + defer func() { err = errs.Combine(err, rows.Close()) }() + if !rows.Next() { + return ErrPreflight.New("database %q: no rows in test_table", dbName) + } + err = rows.Scan(&actualID, &actualName) + if err != nil { + return ErrPreflight.New("database %q: failed scanning row: %w", dbName, err) + } + if expectedID != actualID || expectedName != actualName { + return ErrPreflight.New("database %q: expected (%d, '%s'), actual (%d, '%s')", dbName, expectedID, expectedName, actualID, actualName) + } + if rows.Next() { + return ErrPreflight.New("database %q: more than one row in test_table", dbName) + } + + _, err = nextDB.ExecContext(ctx, "DROP TABLE test_table") + if err != nil { + return ErrPreflight.New("database %q: failed drop test_table %w", dbName, err) + } + + return nil +} + // Close closes any resources. func (db *DB) Close() error { return db.closeDatabases()