// Copyright (C) 2019 Storj Labs, Inc. // See LICENSE for copying information. package cockroachutil import ( "context" "database/sql" "database/sql/driver" "errors" "io" "reflect" "strings" "github.com/lib/pq" "github.com/zeebo/errs" ) // Driver is the type for the "cockroach" sql/database driver. // It uses github.com/lib/pq under the covers because of Cockroach's // PostgreSQL compatibility, but allows differentiation between pg and // crdb connections. type Driver struct { pq.Driver } // Open opens a new cockroachDB connection. func (cd *Driver) Open(name string) (driver.Conn, error) { name = translateName(name) return pq.Open(name) } // OpenConnector obtains a new db Connector, which sql.DB can use to // obtain each needed connection at the appropriate time. func (cd *Driver) OpenConnector(name string) (driver.Connector, error) { name = translateName(name) pgConnector, err := pq.NewConnector(name) if err != nil { return nil, err } return &cockroachConnector{pgConnector}, nil } // cockroachConnector is a thin wrapper around a pq-based connector. This allows // Driver to supply our custom cockroachConn type for connections. type cockroachConnector struct { pgConnector driver.Connector } // Driver returns the driver being used for this connector. func (c *cockroachConnector) Driver() driver.Driver { return &Driver{} } // Connect creates a new connection using the connector. func (c *cockroachConnector) Connect(ctx context.Context) (driver.Conn, error) { pgConn, err := c.pgConnector.Connect(ctx) if err != nil { return nil, err } if pgConnAll, ok := pgConn.(connAll); ok { return &cockroachConn{pgConnAll}, nil } return nil, errs.New("Underlying connector type %T does not implement connAll?!", pgConn) } type connAll interface { driver.Conn driver.ConnBeginTx driver.ExecerContext driver.QueryerContext } // cockroachConn is a connection to a database. It is not used concurrently by multiple goroutines. 
type cockroachConn struct {
	underlying connAll
}

// Assert that cockroachConn fulfills connAll.
var _ connAll = (*cockroachConn)(nil)

// Close closes the cockroachConn by closing the connection it wraps.
func (c *cockroachConn) Close() error {
	return c.underlying.Close()
}

// ExecContext (when implemented by a driver.Conn) provides ExecContext
// functionality to a sql.DB instance. This implementation provides
// retry semantics for single statements: the statement is re-issued for as
// long as it fails with a retryable error while the connection is not inside
// a transaction.
func (c *cockroachConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
	for {
		result, err := c.underlying.ExecContext(ctx, query, args)
		if err == nil || c.isInTransaction() || !NeedsRetry(err) {
			return result, err
		}
		mon.Event("needed_retry")
	}
}

// cockroachRows is a driver.Rows wrapper that can replay a first row which
// was already read from the wrapped rows (see wrapRows).
type cockroachRows struct {
	// rows is the wrapped result set.
	rows driver.Rows
	// firstResults holds the buffered first row until Next hands it out.
	firstResults []driver.Value
	// eof records that the wrapped rows had no rows at all.
	eof bool
}

// Columns returns the names of the columns.
func (rows *cockroachRows) Columns() []string {
	return rows.rows.Columns()
}

// Close closes the rows iterator.
func (rows *cockroachRows) Close() error {
	return rows.rows.Close()
}

// Next implements the Next method on driver.Rows. It reports io.EOF for a
// result set known to be empty, hands out the buffered first row exactly
// once, and otherwise reads from the wrapped rows.
func (rows *cockroachRows) Next(dest []driver.Value) error {
	switch {
	case rows.eof:
		return io.EOF
	case rows.firstResults != nil:
		copy(dest, rows.firstResults)
		rows.firstResults = nil
		return nil
	default:
		return rows.rows.Next(dest)
	}
}

// wrapRows reads the first row from rows and returns a cockroachRows that
// will replay it. An empty result set yields a wrapper that immediately
// reports io.EOF; any other error from the first read is returned as is, and
// closing rows is then left to the caller.
func wrapRows(rows driver.Rows) (crdbRows *cockroachRows, err error) {
	first := make([]driver.Value, len(rows.Columns()))
	switch err = rows.Next(first); err {
	case nil:
		return &cockroachRows{rows: rows, firstResults: first}, nil
	case io.EOF:
		return &cockroachRows{rows: rows, eof: true}, nil
	default:
		return nil, err
	}
}

// QueryContext (when implemented by a driver.Conn) provides QueryContext
// functionality to a sql.DB instance. This implementation provides
// retry semantics for single statements.
func (c *cockroachConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (_ driver.Rows, err error) { defer mon.Task()(&ctx)(&err) for { result, err := c.underlying.QueryContext(ctx, query, args) if err != nil { if NeedsRetry(err) { if c.isInTransaction() { return nil, err } mon.Event("needed_retry") continue } return nil, err } wrappedResult, err := wrapRows(result) if err != nil { // If this returns an error it's probably the same error // we got from calling Next inside wrapRows. _ = result.Close() if NeedsRetry(err) { if c.isInTransaction() { return nil, err } mon.Event("needed_retry") continue } return nil, err } return wrappedResult, nil } } // Begin starts a new transaction. func (c *cockroachConn) Begin() (driver.Tx, error) { return c.BeginTx(context.Background(), driver.TxOptions{}) } // BeginTx begins a new transaction using the specified context and with the specified options. func (c *cockroachConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { return c.underlying.BeginTx(ctx, opts) } // Prepare prepares a statement for future execution. func (c *cockroachConn) Prepare(query string) (driver.Stmt, error) { pqStmt, err := c.underlying.Prepare(query) if err != nil { return nil, err } adapted, ok := pqStmt.(stmtAll) if !ok { return nil, errs.New("Stmt type %T does not provide stmtAll?!", adapted) } return &cockroachStmt{underlyingStmt: adapted, conn: c}, nil } type transactionStatus byte const ( txnStatusIdle transactionStatus = 'I' txnStatusIdleInTransaction transactionStatus = 'T' txnStatusInFailedTransaction transactionStatus = 'E' ) func (c *cockroachConn) txnStatus() transactionStatus { // access c.underlying -> c.underlying.(*pq.conn) -> (*c.underlying.(*pq.conn)).txnStatus // // this is of course brittle if lib/pq internals change, so a test is necessary to make // sure we stay on the same page. 
return transactionStatus(reflect.ValueOf(c.underlying).Elem().Field(4).Uint()) } func (c *cockroachConn) isInTransaction() bool { txnStatus := c.txnStatus() return txnStatus == txnStatusIdleInTransaction || txnStatus == txnStatusInFailedTransaction } type stmtAll interface { driver.Stmt driver.StmtExecContext driver.StmtQueryContext } type cockroachStmt struct { underlyingStmt stmtAll conn *cockroachConn } // Assert that cockroachStmt satisfies StmtExecContext and StmtQueryContext. var _ stmtAll = (*cockroachStmt)(nil) // Close closes a prepared statement. func (stmt *cockroachStmt) Close() error { return stmt.underlyingStmt.Close() } // NumInput returns the number of placeholder parameters. func (stmt *cockroachStmt) NumInput() int { return stmt.underlyingStmt.NumInput() } // Exec executes a SQL statement in the background context. func (stmt *cockroachStmt) Exec(args []driver.Value) (driver.Result, error) { // since (driver.Stmt).Exec() is deprecated, we translate our Value args to NamedValue args // and pass in background context to ExecContext instead. namedArgs := make([]driver.NamedValue, len(args)) for i, arg := range args { namedArgs[i] = driver.NamedValue{Ordinal: i + 1, Value: arg} } result, err := stmt.underlyingStmt.ExecContext(context.Background(), namedArgs) for err != nil && !stmt.conn.isInTransaction() && NeedsRetry(err) { mon.Event("needed_retry") result, err = stmt.underlyingStmt.ExecContext(context.Background(), namedArgs) } return result, err } // Query executes a query in the background context. func (stmt *cockroachStmt) Query(args []driver.Value) (driver.Rows, error) { // since (driver.Stmt).Query() is deprecated, we translate our Value args to NamedValue args // and pass in background context to QueryContext instead. 
namedArgs := make([]driver.NamedValue, len(args)) for i, arg := range args { namedArgs[i] = driver.NamedValue{Ordinal: i + 1, Value: arg} } return stmt.QueryContext(context.Background(), namedArgs) } // ExecContext executes SQL statements in the specified context. func (stmt *cockroachStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { result, err := stmt.underlyingStmt.ExecContext(ctx, args) for err != nil && !stmt.conn.isInTransaction() && NeedsRetry(err) { mon.Event("needed_retry") result, err = stmt.underlyingStmt.ExecContext(ctx, args) } return result, err } // QueryContext executes a query in the specified context. func (stmt *cockroachStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (_ driver.Rows, err error) { defer mon.Task()(&ctx)(&err) for { result, err := stmt.underlyingStmt.QueryContext(ctx, args) if err != nil { if NeedsRetry(err) { if stmt.conn.isInTransaction() { return nil, err } mon.Event("needed_retry") continue } return nil, err } wrappedResult, err := wrapRows(result) if err != nil { // If this returns an error it's probably the same error // we got from calling Next inside wrapRows. _ = result.Close() if NeedsRetry(err) { if stmt.conn.isInTransaction() { return nil, err } mon.Event("needed_retry") continue } return nil, err } return wrappedResult, nil } } // translateName changes the scheme name in a `cockroach://` URL to // `postgres://`, as that is what lib/pq will expect. func translateName(name string) string { if strings.HasPrefix(name, "cockroach://") { name = "postgres://" + name[12:] } return name } // NeedsRetry checks if the error code means a retry is needed, // borrowed from code in crdb. 
func NeedsRetry(err error) bool { code := errCode(err) // 57P01 occurs when a CRDB node rejoins the cluster but is not ready to accept connections // CRDB support recommended a retry at this point // Support ticket: https://support.cockroachlabs.com/hc/en-us/requests/5510 // TODO re-evaluate this if support provides a better solution return code == "40001" || code == "CR000" || code == "57P01" } // borrowed from crdb func errCode(err error) string { switch t := errorCause(err).(type) { case *pq.Error: return string(t.Code) default: return "" } } func errorCause(err error) error { for err != nil { cause := errors.Unwrap(err) if cause == nil { break } err = cause } return err } // Assert that Driver satisfies DriverContext. var _ driver.DriverContext = &Driver{} func init() { sql.Register("cockroach", &Driver{}) }