storj/private/dbutil/cockroachutil/driver.go

362 lines
10 KiB
Go
Raw Normal View History

// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package cockroachutil
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"io"
"reflect"
"strings"
"github.com/lib/pq"
"github.com/zeebo/errs"
)
// Driver is the type for the "cockroach" sql/database driver.
// It uses github.com/lib/pq under the covers because of Cockroach's
// PostgreSQL compatibility, but allows differentiation between pg and
// crdb connections.
//
// pq.Driver is embedded; Driver supplies its own Open and OpenConnector,
// which rewrite `cockroach://` URLs to `postgres://` before handing them
// to lib/pq.
type Driver struct {
	pq.Driver
}
// Open opens a new cockroachDB connection. A `cockroach://` scheme in name
// is rewritten to `postgres://` and the connection itself is opened by
// lib/pq.
func (cd *Driver) Open(name string) (driver.Conn, error) {
	return pq.Open(translateName(name))
}
// OpenConnector obtains a new db Connector, which sql.DB can use to
// obtain each needed connection at the appropriate time.
//
// The name is translated from `cockroach://` to `postgres://`, a lib/pq
// connector is built from it, and that connector is wrapped so every
// connection it produces is a cockroachConn.
func (cd *Driver) OpenConnector(name string) (driver.Connector, error) {
	connector, err := pq.NewConnector(translateName(name))
	if err != nil {
		return nil, err
	}
	return &cockroachConnector{pgConnector: connector}, nil
}
// cockroachConnector is a thin wrapper around a pq-based connector. This allows
// Driver to supply our custom cockroachConn type for connections.
type cockroachConnector struct {
	// pgConnector is the wrapped lib/pq connector that actually dials.
	pgConnector driver.Connector
}
// Driver returns the driver being used for this connector: a fresh, zero
// valued *Driver.
func (c *cockroachConnector) Driver() driver.Driver {
	return new(Driver)
}
// Connect creates a new connection using the connector.
//
// The connection produced by the wrapped pq connector must implement
// connAll so it can be wrapped in a cockroachConn. If it does not, the
// connection is closed before returning the error, so it is not leaked.
// Any error from that Close is combined into the returned error.
func (c *cockroachConnector) Connect(ctx context.Context) (driver.Conn, error) {
	pgConn, err := c.pgConnector.Connect(ctx)
	if err != nil {
		return nil, err
	}

	pgConnAll, ok := pgConn.(connAll)
	if !ok {
		typeErr := errs.New("Underlying connector type %T does not implement connAll?!", pgConn)
		return nil, errs.Combine(typeErr, pgConn.Close())
	}
	return &cockroachConn{pgConnAll}, nil
}
// connAll is the set of driver interfaces a wrapped connection must provide
// so that cockroachConn can delegate to it (Close/Prepare/Begin, BeginTx,
// ExecContext and QueryContext).
type connAll interface {
	driver.Conn
	driver.ConnBeginTx
	driver.ExecerContext
	driver.QueryerContext
}
// cockroachConn is a connection to a database. It is not used concurrently by multiple goroutines.
//
// It delegates to the wrapped lib/pq connection and adds statement-level
// retry semantics to ExecContext and QueryContext.
type cockroachConn struct {
	// underlying is the wrapped connection returned by the pq connector.
	underlying connAll
}

// Assert that cockroachConn fulfills connAll (compile-time check).
var _ connAll = (*cockroachConn)(nil)
// Close closes the cockroachConn by closing the wrapped connection and
// returning whatever error it reports.
func (c *cockroachConn) Close() error {
	err := c.underlying.Close()
	return err
}
// ExecContext (when implemented by a driver.Conn) provides ExecContext
// functionality to a sql.DB instance. This implementation provides
// retry semantics for single statements.
//
// A failed statement is re-executed while the connection is not inside a
// transaction and the error is retryable (see NeedsRetry). Retrying also
// stops once ctx is done, so a persistently retryable error cannot keep the
// loop spinning after the caller has given up. In every case the last
// result and error are returned.
func (c *cockroachConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (_ driver.Result, err error) {
	defer mon.Task()(&ctx)(&err)

	result, err := c.underlying.ExecContext(ctx, query, args)
	for err != nil && !c.isInTransaction() && NeedsRetry(err) && ctx.Err() == nil {
		mon.Event("needed_retry")
		result, err = c.underlying.ExecContext(ctx, query, args)
	}
	return result, err
}
// cockroachRows wraps a driver.Rows whose first row has already been
// fetched by wrapRows. Reading that row eagerly lets QueryContext see (and
// retry) an error that is only reported while fetching the first row.
type cockroachRows struct {
	// rows is the wrapped result set.
	rows driver.Rows
	// firstResults holds the pre-fetched first row until Next hands it out;
	// nil once consumed (or when there was no first row).
	firstResults []driver.Value
	// eof is set when the pre-fetch already hit io.EOF (empty result set).
	eof bool
}

// Columns returns the names of the columns of the wrapped rows.
func (r *cockroachRows) Columns() []string {
	return r.rows.Columns()
}

// Close closes the wrapped rows iterator.
func (r *cockroachRows) Close() error {
	return r.rows.Close()
}

// Next implements the Next method on driver.Rows.
//
// An empty result set reports io.EOF without touching the wrapped rows. The
// buffered first row (if any) is copied into dest on the first call. Every
// later call is forwarded to the wrapped rows.
func (r *cockroachRows) Next(dest []driver.Value) error {
	switch {
	case r.eof:
		return io.EOF
	case r.firstResults != nil:
		copy(dest, r.firstResults)
		r.firstResults = nil
		return nil
	default:
		return r.rows.Next(dest)
	}
}

// wrapRows reads the first row of rows and returns a cockroachRows that will
// replay it. If rows is empty (io.EOF), the result is marked as exhausted.
// Any other error from reading the first row is returned with nil rows; the
// caller remains responsible for closing rows in that case.
func wrapRows(rows driver.Rows) (crdbRows *cockroachRows, err error) {
	first := make([]driver.Value, len(rows.Columns()))
	err = rows.Next(first)
	switch {
	case err == nil:
		return &cockroachRows{rows: rows, firstResults: first}, nil
	case err == io.EOF:
		return &cockroachRows{rows: rows, eof: true}, nil
	default:
		return nil, err
	}
}
// QueryContext (when implemented by a driver.Conn) provides QueryContext
// functionality to a sql.DB instance. This implementation provides
// retry semantics for single statements.
//
// The first row is read eagerly (via wrapRows), so a retryable error that
// only appears while fetching it is retried as well. Retrying stops as
// soon as any of these holds:
//   - the error is not retryable (see NeedsRetry);
//   - the connection is inside a transaction;
//   - ctx is done, so the loop cannot spin forever after the caller gave up.
// The last error is then returned.
func (c *cockroachConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (_ driver.Rows, err error) {
	defer mon.Task()(&ctx)(&err)
	for {
		var rows driver.Rows
		rows, err = c.underlying.QueryContext(ctx, query, args)
		if err == nil {
			var wrapped *cockroachRows
			wrapped, err = wrapRows(rows)
			if err == nil {
				return wrapped, nil
			}
			// wrapRows failed while reading the first row. A Close error here
			// is most likely that same error, so it is deliberately discarded.
			_ = rows.Close()
		}

		if !NeedsRetry(err) || c.isInTransaction() || ctx.Err() != nil {
			return nil, err
		}
		mon.Event("needed_retry")
	}
}
// Begin starts a new transaction with default options and a background
// context, by delegating to BeginTx.
func (c *cockroachConn) Begin() (driver.Tx, error) {
	var opts driver.TxOptions
	return c.BeginTx(context.Background(), opts)
}
// BeginTx begins a new transaction using the specified context and with the
// specified options. The call is passed straight to the wrapped connection;
// no retry logic is applied here.
func (c *cockroachConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
	tx, err := c.underlying.BeginTx(ctx, opts)
	return tx, err
}
// Prepare prepares a statement for future execution.
//
// The statement prepared by the wrapped connection must implement stmtAll,
// so it can be wrapped in a cockroachStmt that retries on retryable errors.
// If it does not, the statement is closed (not leaked) and an error naming
// its concrete type is returned, combined with any error from that Close.
func (c *cockroachConn) Prepare(query string) (driver.Stmt, error) {
	pqStmt, err := c.underlying.Prepare(query)
	if err != nil {
		return nil, err
	}

	adapted, ok := pqStmt.(stmtAll)
	if !ok {
		// Report the type of the statement we actually got; `adapted` is
		// always the nil interface on this path.
		typeErr := errs.New("Stmt type %T does not provide stmtAll?!", pqStmt)
		return nil, errs.Combine(typeErr, pqStmt.Close())
	}
	return &cockroachStmt{underlyingStmt: adapted, conn: c}, nil
}
// transactionStatus is the connection's transaction state as read from
// lib/pq's internal txnStatus field (see txnStatus).
// NOTE(review): presumably the status byte from the server's ReadyForQuery
// message that lib/pq records — verify against lib/pq if it changes.
type transactionStatus byte

const (
	// txnStatusIdle: not inside a transaction block.
	txnStatusIdle transactionStatus = 'I'
	// txnStatusIdleInTransaction: inside a transaction block.
	txnStatusIdleInTransaction transactionStatus = 'T'
	// txnStatusInFailedTransaction: inside a failed transaction block.
	txnStatusInFailedTransaction transactionStatus = 'E'
)
// txnStatus reports the transaction status that lib/pq tracks for the
// wrapped connection.
//
// lib/pq does not export this value, so it is read via reflection. The code
// expects c.underlying to hold a pointer to lib/pq's unexported conn struct,
// and assumes that struct's fifth field (index 4) is txnStatus. reflect
// panics if that assumption breaks: a non-pointer value, fewer than five
// fields, or a field that is not an unsigned integer kind.
func (c *cockroachConn) txnStatus() transactionStatus {
	// access c.underlying -> c.underlying.(*pq.conn) -> (*c.underlying.(*pq.conn)).txnStatus
	//
	// this is of course brittle if lib/pq internals change, so a test is necessary to make
	// sure we stay on the same page.
	return transactionStatus(reflect.ValueOf(c.underlying).Elem().Field(4).Uint())
}
// isInTransaction reports whether the connection is currently inside a
// transaction block, including a failed one. Statement-level retries are
// skipped in that state.
func (c *cockroachConn) isInTransaction() bool {
	switch c.txnStatus() {
	case txnStatusIdleInTransaction, txnStatusInFailedTransaction:
		return true
	default:
		return false
	}
}
// stmtAll is the set of driver interfaces a prepared statement must provide
// so that cockroachStmt can delegate to it (Close/NumInput/Exec/Query plus
// the context-aware ExecContext and QueryContext).
type stmtAll interface {
	driver.Stmt
	driver.StmtExecContext
	driver.StmtQueryContext
}
// cockroachStmt wraps a prepared statement from the underlying connection
// and adds statement-level retry semantics to its Exec/Query methods.
type cockroachStmt struct {
	// underlyingStmt is the wrapped prepared statement.
	underlyingStmt stmtAll
	// conn is the connection the statement was prepared on; it is consulted
	// to find out whether a transaction is in progress before retrying.
	conn *cockroachConn
}

// Assert that cockroachStmt satisfies driver.Stmt, StmtExecContext and
// StmtQueryContext (compile-time check).
var _ stmtAll = (*cockroachStmt)(nil)
// Close closes a prepared statement by closing the wrapped statement and
// returning its error.
func (stmt *cockroachStmt) Close() error {
	err := stmt.underlyingStmt.Close()
	return err
}
// NumInput returns the number of placeholder parameters, exactly as the
// wrapped statement reports it.
func (stmt *cockroachStmt) NumInput() int {
	n := stmt.underlyingStmt.NumInput()
	return n
}
// Exec executes a SQL statement in the background context.
//
// (driver.Stmt).Exec is deprecated. The positional Value args are converted
// to NamedValue args (ordinals start at 1), and the call is delegated to
// ExecContext with context.Background(). Query delegates to QueryContext in
// the same way. Delegating keeps a single copy of the retry policy instead
// of a duplicated loop that could drift out of sync.
func (stmt *cockroachStmt) Exec(args []driver.Value) (driver.Result, error) {
	namedArgs := make([]driver.NamedValue, len(args))
	for i, arg := range args {
		namedArgs[i] = driver.NamedValue{Ordinal: i + 1, Value: arg}
	}
	return stmt.ExecContext(context.Background(), namedArgs)
}
// Query executes a query in the background context.
//
// (driver.Stmt).Query is deprecated. The positional Value args are turned
// into NamedValue args with 1-based ordinals, and QueryContext is called
// with context.Background().
func (stmt *cockroachStmt) Query(args []driver.Value) (driver.Rows, error) {
	named := make([]driver.NamedValue, 0, len(args))
	for i := range args {
		named = append(named, driver.NamedValue{Ordinal: i + 1, Value: args[i]})
	}
	return stmt.QueryContext(context.Background(), named)
}
// ExecContext executes SQL statements in the specified context.
//
// A failed execution is retried while the statement's connection is not
// inside a transaction and the error is retryable (see NeedsRetry).
// Retrying also stops once ctx is done, so a persistently retryable error
// cannot keep the loop spinning after the caller has given up. The last
// result and error are returned.
func (stmt *cockroachStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (_ driver.Result, err error) {
	defer mon.Task()(&ctx)(&err)

	result, err := stmt.underlyingStmt.ExecContext(ctx, args)
	for err != nil && !stmt.conn.isInTransaction() && NeedsRetry(err) && ctx.Err() == nil {
		mon.Event("needed_retry")
		result, err = stmt.underlyingStmt.ExecContext(ctx, args)
	}
	return result, err
}
// QueryContext executes a query in the specified context.
//
// The first row is read eagerly (via wrapRows), so a retryable error that
// only appears while fetching it is retried as well. Retrying stops as
// soon as any of these holds:
//   - the error is not retryable (see NeedsRetry);
//   - the statement's connection is inside a transaction;
//   - ctx is done.
// The last error is then returned.
func (stmt *cockroachStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (_ driver.Rows, err error) {
	defer mon.Task()(&ctx)(&err)
	for {
		var rows driver.Rows
		rows, err = stmt.underlyingStmt.QueryContext(ctx, args)
		if err == nil {
			var wrapped *cockroachRows
			wrapped, err = wrapRows(rows)
			if err == nil {
				return wrapped, nil
			}
			// wrapRows failed while reading the first row. A Close error here
			// is most likely that same error, so it is deliberately discarded.
			_ = rows.Close()
		}

		if !NeedsRetry(err) || stmt.conn.isInTransaction() || ctx.Err() != nil {
			return nil, err
		}
		mon.Event("needed_retry")
	}
}
// translateName changes the scheme name in a `cockroach://` URL to
// `postgres://`, as that is what lib/pq will expect. Any other name
// (including one with a different scheme) is returned unchanged.
func translateName(name string) string {
	const crdbScheme = "cockroach://"
	if !strings.HasPrefix(name, crdbScheme) {
		return name
	}
	return "postgres://" + strings.TrimPrefix(name, crdbScheme)
}
// NeedsRetry checks if the error code means a retry is needed,
// borrowed from code in crdb. The code is obtained with errCode; errors
// without a recognised code never need a retry.
func NeedsRetry(err error) bool {
	switch errCode(err) {
	case "40001", "CR000":
		return true
	case "57P01":
		// 57P01 occurs when a CRDB node rejoins the cluster but is not ready to accept connections
		// CRDB support recommended a retry at this point
		// Support ticket: https://support.cockroachlabs.com/hc/en-us/requests/5510
		// TODO re-evaluate this if support provides a better solution
		return true
	}
	return false
}
// errCode returns the Code of the *pq.Error found at the root of err's
// wrap chain (as computed by errorCause), or "" if the root is not a
// *pq.Error. Borrowed from crdb.
func errCode(err error) string {
	if pqErr, ok := errorCause(err).(*pq.Error); ok {
		return string(pqErr.Code)
	}
	return ""
}
// errorCause follows err's chain with errors.Unwrap and returns the
// innermost error, i.e. the first one that does not wrap anything. A nil err
// yields nil; an error that wraps nothing is returned as is.
func errorCause(err error) error {
	for {
		next := errors.Unwrap(err)
		if next == nil {
			return err
		}
		err = next
	}
}
// Assert that Driver satisfies DriverContext (compile-time check), so that
// database/sql obtains connections through OpenConnector.
var _ driver.DriverContext = (*Driver)(nil)

// init registers Driver with database/sql under the name "cockroach".
func init() {
	sql.Register("cockroach", new(Driver))
}