storj/private/dbutil/cockroachutil/driver.go
Natalie Villasana 8d87a6efc9 cockroachutil/driver: handle retryable errors returned from Next
This only works if retryable errors are returned on the first
call to Next. If they are returned later, handling them will require
deeper changes at the application code level throughout the
codebase 😬👎

Change-Id: I46d795a13670f66b7f085605ba1b779f69c339c3
2020-05-15 14:49:43 -04:00

347 lines
9.6 KiB
Go

// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package cockroachutil
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"io"
"reflect"
"strings"
"github.com/lib/pq"
"github.com/zeebo/errs"
)
// Driver is the type for the "cockroach" sql/database driver.
// It uses github.com/lib/pq under the covers because of Cockroach's
// PostgreSQL compatibility, but allows differentiation between pg and
// crdb connections.
type Driver struct {
	pq.Driver
}

// Open opens a new cockroachDB connection. The name is first passed through
// translateName, so a cockroach:// URL is handed to lib/pq as postgres://.
func (cd *Driver) Open(name string) (driver.Conn, error) {
	return pq.Open(translateName(name))
}
// OpenConnector obtains a new db Connector, which sql.DB can use to
// obtain each needed connection at the appropriate time.
//
// The (translated) name is handed to pq.NewConnector, and the resulting
// connector is wrapped in a cockroachConnector so that the connections it
// produces are cockroachConn values.
func (cd *Driver) OpenConnector(name string) (driver.Connector, error) {
	pgConnector, err := pq.NewConnector(translateName(name))
	if err != nil {
		return nil, err
	}
	return &cockroachConnector{pgConnector: pgConnector}, nil
}
// cockroachConnector is a thin wrapper around a pq-based connector. This allows
// Driver to supply our custom cockroachConn type for connections.
type cockroachConnector struct {
	pgConnector driver.Connector
}

// Driver returns the driver being used for this connector. The connector keeps
// no reference to the Driver that created it, so a new zero-valued *Driver is
// returned.
func (c *cockroachConnector) Driver() driver.Driver {
	return new(Driver)
}
// Connect creates a new connection using the wrapped lib/pq connector and
// wraps it in a cockroachConn.
//
// The underlying connection must implement every interface in connAll. If it
// does not, the connection is closed (instead of being leaked) and an error
// naming its concrete type is returned.
func (c *cockroachConnector) Connect(ctx context.Context) (driver.Conn, error) {
	pgConn, err := c.pgConnector.Connect(ctx)
	if err != nil {
		return nil, err
	}
	pgConnAll, ok := pgConn.(connAll)
	if !ok {
		// Best effort: we cannot use this connection, so release it. The
		// type-mismatch error below is the one worth reporting.
		_ = pgConn.Close()
		return nil, errs.New("Underlying connector type %T does not implement connAll?!", pgConn)
	}
	return &cockroachConn{underlying: pgConnAll}, nil
}
// connAll is the set of driver interfaces cockroachConn needs from the
// underlying lib/pq connection. cockroachConnector.Connect rejects any
// connection that does not implement all of them.
type connAll interface {
driver.Conn
driver.ConnBeginTx
driver.ExecerContext
driver.QueryerContext
}
// cockroachConn is a connection to a database. It is not used concurrently by multiple goroutines.
// It forwards to the wrapped lib/pq connection and adds retry semantics to
// single statements executed outside an explicit transaction.
type cockroachConn struct {
// underlying is the wrapped lib/pq connection. txnStatus reads an unexported
// field of the value it points to via reflection, so it is expected to be a
// pointer to lib/pq's conn struct.
underlying connAll
}
// Assert that cockroachConn fulfills connAll.
var _ connAll = (*cockroachConn)(nil)
// Close closes the cockroachConn by closing the wrapped lib/pq connection.
func (c *cockroachConn) Close() error {
	pgConn := c.underlying
	return pgConn.Close()
}
// ExecContext (when implemented by a driver.Conn) provides ExecContext
// functionality to a sql.DB instance. This implementation provides
// retry semantics for single statements.
//
// When the statement fails with an error that needsRetry accepts and the
// connection is not inside an explicit transaction, the statement is executed
// again. Inside a transaction the error is returned as-is.
//
// Retrying stops as soon as ctx is done, and ctx.Err() is returned. This keeps
// the loop from issuing more statements for a caller that has already given up.
func (c *cockroachConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
	for {
		result, err := c.underlying.ExecContext(ctx, query, args)
		if err == nil || c.isInTransaction() || !needsRetry(err) {
			return result, err
		}
		if ctxErr := ctx.Err(); ctxErr != nil {
			return nil, ctxErr
		}
	}
}
// cockroachRows wraps the driver.Rows returned by lib/pq. wrapRows reads the
// first row eagerly; that row (if any) is buffered here and handed out by the
// first call to Next.
type cockroachRows struct {
	// rows is the wrapped result set.
	rows driver.Rows
	// firstResults holds the eagerly read first row until Next returns it.
	// It is nil once consumed, or if no row was buffered.
	firstResults []driver.Value
	// eof is set when the wrapped rows were already exhausted by the eager
	// read, so Next must report io.EOF without touching rows again.
	eof bool
}

// Columns returns the names of the columns.
func (r *cockroachRows) Columns() []string {
	names := r.rows.Columns()
	return names
}

// Close closes the rows iterator.
func (r *cockroachRows) Close() error {
	inner := r.rows
	return inner.Close()
}

// Next implements the Next method on driver.Rows. It first yields the
// buffered first row (if any) and then defers to the wrapped rows.
func (r *cockroachRows) Next(dest []driver.Value) error {
	switch {
	case r.eof:
		return io.EOF
	case r.firstResults != nil:
		copy(dest, r.firstResults)
		r.firstResults = nil
		return nil
	default:
		return r.rows.Next(dest)
	}
}
// wrapRows eagerly reads the first row of rows so that an error returned by
// the first call to Next (such as a retryable CockroachDB error) reaches the
// Query caller, which can still retry the query.
//
// If a row was read, it is buffered in the returned cockroachRows. If rows was
// already exhausted, the result is marked eof. Any other error from Next is
// returned with nil rows, and the caller remains responsible for closing rows.
func wrapRows(rows driver.Rows) (crdbRows *cockroachRows, err error) {
	firstRow := make([]driver.Value, len(rows.Columns()))
	switch err = rows.Next(firstRow); err {
	case nil:
		crdbRows = &cockroachRows{rows: rows, firstResults: firstRow}
	case io.EOF:
		crdbRows, err = &cockroachRows{rows: rows, eof: true}, nil
	}
	return crdbRows, err
}
// QueryContext (when implemented by a driver.Conn) provides QueryContext
// functionality to a sql.DB instance. This implementation provides
// retry semantics for single statements.
//
// A retryable error can surface either from the query itself or from the
// first call to Next, which wrapRows performs eagerly before the rows are
// returned. In both cases, the query is issued again as long as needsRetry
// accepts the error and the connection is not inside an explicit transaction.
// Errors returned by later calls to Next are not retried.
//
// Retrying stops as soon as ctx is done, and ctx.Err() is returned.
func (c *cockroachConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
	for {
		result, err := c.underlying.QueryContext(ctx, query, args)
		if err == nil {
			wrapped, wrapErr := wrapRows(result)
			if wrapErr == nil {
				return wrapped, nil
			}
			// If Close returns an error, it's probably the same error we got
			// from calling Next inside wrapRows, so it is dropped.
			_ = result.Close()
			err = wrapErr
		}
		if !needsRetry(err) || c.isInTransaction() {
			return nil, err
		}
		if ctxErr := ctx.Err(); ctxErr != nil {
			return nil, ctxErr
		}
	}
}
// Begin starts a new transaction with default options and a background
// context. It exists to satisfy the legacy driver.Conn interface.
func (c *cockroachConn) Begin() (driver.Tx, error) {
	var defaultOpts driver.TxOptions
	return c.BeginTx(context.Background(), defaultOpts)
}

// BeginTx begins a new transaction using the specified context and with the
// specified options. The call is forwarded unchanged to the wrapped
// connection, and no retry logic is applied here.
func (c *cockroachConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
	tx, err := c.underlying.BeginTx(ctx, opts)
	return tx, err
}
// Prepare prepares a statement for future execution. The lib/pq statement is
// wrapped in a cockroachStmt so that executing it gets the same retry
// handling as statements run directly on this connection.
//
// If the prepared statement does not implement stmtAll, it is closed and an
// error naming its concrete type is returned.
func (c *cockroachConn) Prepare(query string) (driver.Stmt, error) {
	pqStmt, err := c.underlying.Prepare(query)
	if err != nil {
		return nil, err
	}
	adapted, ok := pqStmt.(stmtAll)
	if !ok {
		// adapted is a nil interface on this path, so report pqStmt's type.
		// The statement is released rather than leaked (best effort).
		_ = pqStmt.Close()
		return nil, errs.New("Stmt type %T does not provide stmtAll?!", pqStmt)
	}
	return &cockroachStmt{underlyingStmt: adapted, conn: c}, nil
}
// transactionStatus is the transaction-status byte that lib/pq records on its
// connection. txnStatus reads it via reflection and converts it to this type.
type transactionStatus byte
// Known transactionStatus values. isInTransaction consults only the two
// "in transaction" states; txnStatusIdle is listed for completeness.
const (
// txnStatusIdle: not inside a transaction block.
txnStatusIdle transactionStatus = 'I'
// txnStatusIdleInTransaction: inside an open transaction block.
txnStatusIdleInTransaction transactionStatus = 'T'
// txnStatusInFailedTransaction: inside a transaction block that has failed.
txnStatusInFailedTransaction transactionStatus = 'E'
)
// txnStatus returns the transaction status that lib/pq tracks for the
// underlying connection.
//
// lib/pq keeps this in an unexported field, so it is read via reflection:
// c.underlying -> c.underlying.(*pq.conn) -> (*c.underlying.(*pq.conn)).txnStatus.
// The field is looked up by name rather than by position. That way, a
// reordering of lib/pq's struct fields cannot silently make us read an
// unrelated field.
//
// This is still brittle if lib/pq internals change: renaming the field makes
// this panic. A test is therefore necessary to make sure we stay on the same
// page.
func (c *cockroachConn) txnStatus() transactionStatus {
	pqConn := reflect.ValueOf(c.underlying).Elem()
	return transactionStatus(pqConn.FieldByName("txnStatus").Uint())
}
// isInTransaction reports whether lib/pq considers the connection to be inside
// a transaction block, either an open one ('T') or one that has failed ('E').
// The retry paths use it to avoid re-running a single statement in the middle
// of a transaction.
func (c *cockroachConn) isInTransaction() bool {
	switch c.txnStatus() {
	case txnStatusIdleInTransaction, txnStatusInFailedTransaction:
		return true
	default:
		return false
	}
}
// stmtAll is the set of driver statement interfaces that cockroachStmt needs
// from a lib/pq prepared statement; cockroachConn.Prepare checks for it.
type stmtAll interface {
driver.Stmt
driver.StmtExecContext
driver.StmtQueryContext
}
// cockroachStmt wraps a prepared lib/pq statement and retries retryable
// errors for statements run outside a transaction.
type cockroachStmt struct {
// underlyingStmt is the wrapped prepared statement.
underlyingStmt stmtAll
// conn is the connection the statement was prepared on. It is consulted to
// decide whether a retry is allowed (no retries inside a transaction).
conn *cockroachConn
}
// Assert that cockroachStmt satisfies stmtAll (driver.Stmt, StmtExecContext and StmtQueryContext).
var _ stmtAll = (*cockroachStmt)(nil)
// Close closes a prepared statement by closing the wrapped lib/pq statement.
func (stmt *cockroachStmt) Close() error {
	underlying := stmt.underlyingStmt
	return underlying.Close()
}

// NumInput returns the number of placeholder parameters, as reported by the
// wrapped statement.
func (stmt *cockroachStmt) NumInput() int {
	count := stmt.underlyingStmt.NumInput()
	return count
}
// Exec executes a SQL statement in the background context.
//
// (driver.Stmt).Exec is deprecated, so the positional Value args are
// translated into 1-based ordinal NamedValue args and forwarded to
// ExecContext. ExecContext owns the retry logic, which mirrors how Query
// forwards to QueryContext.
func (stmt *cockroachStmt) Exec(args []driver.Value) (driver.Result, error) {
	namedArgs := make([]driver.NamedValue, len(args))
	for i, arg := range args {
		namedArgs[i] = driver.NamedValue{Ordinal: i + 1, Value: arg}
	}
	return stmt.ExecContext(context.Background(), namedArgs)
}
// Query executes a query in the background context.
//
// (driver.Stmt).Query is deprecated, so the positional Value args are
// converted into 1-based ordinal NamedValue args and handed to QueryContext.
func (stmt *cockroachStmt) Query(args []driver.Value) (driver.Rows, error) {
	namedArgs := make([]driver.NamedValue, 0, len(args))
	for i := range args {
		namedArgs = append(namedArgs, driver.NamedValue{Ordinal: i + 1, Value: args[i]})
	}
	return stmt.QueryContext(context.Background(), namedArgs)
}
// ExecContext executes SQL statements in the specified context.
//
// When execution fails with an error that needsRetry accepts and the
// statement's connection is not inside an explicit transaction, the statement
// is executed again. Inside a transaction the error is returned as-is.
//
// Retrying stops as soon as ctx is done, and ctx.Err() is returned.
func (stmt *cockroachStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
	for {
		result, err := stmt.underlyingStmt.ExecContext(ctx, args)
		if err == nil || stmt.conn.isInTransaction() || !needsRetry(err) {
			return result, err
		}
		if ctxErr := ctx.Err(); ctxErr != nil {
			return nil, ctxErr
		}
	}
}
// QueryContext executes a query in the specified context.
//
// A retryable error can surface either from the query itself or from the
// first call to Next, which wrapRows performs eagerly before the rows are
// returned. In both cases, the query is issued again as long as needsRetry
// accepts the error and the statement's connection is not inside an explicit
// transaction. Errors returned by later calls to Next are not retried.
//
// Retrying stops as soon as ctx is done, and ctx.Err() is returned.
func (stmt *cockroachStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
	for {
		result, err := stmt.underlyingStmt.QueryContext(ctx, args)
		if err == nil {
			wrapped, wrapErr := wrapRows(result)
			if wrapErr == nil {
				return wrapped, nil
			}
			// If Close returns an error, it's probably the same error we got
			// from calling Next inside wrapRows, so it is dropped.
			_ = result.Close()
			err = wrapErr
		}
		if !needsRetry(err) || stmt.conn.isInTransaction() {
			return nil, err
		}
		if ctxErr := ctx.Err(); ctxErr != nil {
			return nil, ctxErr
		}
	}
}
// translateName changes the scheme name in a `cockroach://` URL to
// `postgres://`, as that is what lib/pq will expect. Any name without that
// exact (case-sensitive) prefix is returned unchanged.
func translateName(name string) string {
	const (
		crdbScheme = "cockroach://"
		pgScheme   = "postgres://"
	)
	if rest := strings.TrimPrefix(name, crdbScheme); rest != name {
		return pgScheme + rest
	}
	return name
}
// needsRetry reports whether err carries one of the error codes that crdb
// treats as retryable ("40001" or "CR000").
// Borrowed from code in crdb.
func needsRetry(err error) bool {
	switch errCode(err) {
	case "40001", "CR000":
		return true
	}
	return false
}
// errCode returns the error code of the *pq.Error at the root of err's Unwrap
// chain (see errorCause), or "" if the root cause is not a *pq.Error.
// Borrowed from crdb.
func errCode(err error) string {
	if pqErr, ok := errorCause(err).(*pq.Error); ok {
		return string(pqErr.Code)
	}
	return ""
}
// errorCause follows err's chain via errors.Unwrap and returns the innermost
// error: the first one that does not unwrap any further. A nil err yields nil.
func errorCause(err error) error {
	for {
		inner := errors.Unwrap(err)
		if inner == nil {
			return err
		}
		err = inner
	}
}
// Assert that Driver satisfies DriverContext, so database/sql obtains
// connections through OpenConnector.
var _ driver.DriverContext = (*Driver)(nil)

// init registers Driver with database/sql under the name "cockroach".
func init() {
	sql.Register("cockroach", new(Driver))
}