storj/private/dbutil/utccheck/db.go
Jeff Wendling fb8e78132d storagenodedb: reenable utccheck in tests
Change-Id: If7d64dd4ae58e4b656ff9122ae3195b2a5173cb3
2019-12-10 23:17:14 +00:00

300 lines
6.9 KiB
Go

// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package utccheck
import (
"context"
"database/sql/driver"
"time"
"github.com/zeebo/errs"
)
// Connector wraps a driver.Connector with utc checks.
type Connector struct {
connector driver.Connector
}
// WrapConnector wraps a driver.Connector with utc checks.
func WrapConnector(connector driver.Connector) *Connector {
return &Connector{connector: connector}
}
// Unwrap returns the underlying driver.Connector.
func (c *Connector) Unwrap() driver.Connector { return c.connector }
// Connect returns a wrapped driver.Conn with utc checks.
func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
conn, err := c.connector.Connect(ctx)
if err != nil {
return nil, errs.Wrap(err)
}
return WrapConn(conn), nil
}
// Driver returns a wrapped driver.Driver with utc checks.
func (c *Connector) Driver() driver.Driver {
return WrapDriver(c.connector.Driver())
}
//
// driver
//
// Driver wraps a driver.Driver with utc checks.
type Driver struct {
driver driver.Driver
}
// WrapDriver wraps a driver.Driver with utc checks.
func WrapDriver(driver driver.Driver) *Driver {
return &Driver{driver: driver}
}
// Unwrap returns the underlying driver.Driver.
func (d *Driver) Unwrap() driver.Driver { return d.driver }
// Open returns a wrapped driver.Conn with utc checks.
func (d *Driver) Open(name string) (driver.Conn, error) {
conn, err := d.driver.Open(name)
if err != nil {
return nil, errs.Wrap(err)
}
return WrapConn(conn), nil
}
//
// conn
//
// Conn wraps a driver.Conn with utc checks.
type Conn struct {
conn driver.Conn
}
// WrapConn wraps a driver.Conn with utc checks.
func WrapConn(conn driver.Conn) *Conn {
return &Conn{conn: conn}
}
// Unwrap returns the underlying driver.Conn.
func (c *Conn) Unwrap() driver.Conn { return c.conn }
// Close closes the conn.
func (c *Conn) Close() error {
return c.conn.Close()
}
// Ping implements driver.Pinger.
func (c *Conn) Ping(ctx context.Context) error {
// sqlite3 implements this
return c.conn.(driver.Pinger).Ping(ctx)
}
// Begin returns a wrapped driver.Tx with utc checks.
func (c *Conn) Begin() (driver.Tx, error) {
//lint:ignore SA1019 deprecated is fine. this is a wrapper.
//nolint
tx, err := c.conn.Begin()
if err != nil {
return nil, errs.Wrap(err)
}
return WrapTx(tx), nil
}
// BeginTx returns a wrapped driver.Tx with utc checks.
func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
// sqlite3 implements this
tx, err := c.conn.(driver.ConnBeginTx).BeginTx(ctx, opts)
if err != nil {
return nil, errs.Wrap(err)
}
return WrapTx(tx), nil
}
// Query checks the arguments for non-utc timestamps and returns the result.
func (c *Conn) Query(query string, args []driver.Value) (driver.Rows, error) {
if err := utcCheckArgs(args); err != nil {
return nil, err
}
// sqlite3 implements this
//
//lint:ignore SA1019 deprecated is fine. this is a wrapper.
//nolint
return c.conn.(driver.Queryer).Query(query, args)
}
// QueryContext checks the arguments for non-utc timestamps and returns the result.
func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
if err := utcCheckNamedArgs(args); err != nil {
return nil, err
}
// sqlite3 implements this
return c.conn.(driver.QueryerContext).QueryContext(ctx, query, args)
}
// Exec checks the arguments for non-utc timestamps and returns the result.
func (c *Conn) Exec(query string, args []driver.Value) (driver.Result, error) {
if err := utcCheckArgs(args); err != nil {
return nil, err
}
// sqlite3 implements this
//
//lint:ignore SA1019 deprecated is fine. this is a wrapper.
//nolint
return c.conn.(driver.Execer).Exec(query, args)
}
// ExecContext checks the arguments for non-utc timestamps and returns the result.
func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
if err := utcCheckNamedArgs(args); err != nil {
return nil, err
}
// sqlite3 implements this
return c.conn.(driver.ExecerContext).ExecContext(ctx, query, args)
}
// Prepare returns a wrapped driver.Stmt with utc checks.
func (c *Conn) Prepare(query string) (driver.Stmt, error) {
stmt, err := c.conn.Prepare(query)
if err != nil {
return nil, errs.Wrap(err)
}
return WrapStmt(stmt), nil
}
// PrepareContext checks the arguments for non-utc timestamps and returns the result.
func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
// sqlite3 implements this
stmt, err := c.conn.(driver.ConnPrepareContext).PrepareContext(ctx, query)
if err != nil {
return nil, errs.Wrap(err)
}
return WrapStmt(stmt), nil
}
//
// stmt
//
// Stmt wraps a driver.Stmt with utc checks.
type Stmt struct {
stmt driver.Stmt
}
// WrapStmt wraps a driver.Stmt with utc checks.
func WrapStmt(stmt driver.Stmt) *Stmt {
return &Stmt{stmt: stmt}
}
// Unwrap returns the underlying driver.Stmt.
func (s *Stmt) Unwrap() driver.Stmt { return s.stmt }
// Close closes the stmt.
func (s *Stmt) Close() error {
return s.stmt.Close()
}
// NumInput returns the number of inputs to the stmt.
func (s *Stmt) NumInput() int {
return s.stmt.NumInput()
}
// Exec checks the arguments for non-utc timestamps and returns the result.
func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) {
if err := utcCheckArgs(args); err != nil {
return nil, errs.Wrap(err)
}
//lint:ignore SA1019 deprecated is fine. this is a wrapper.
//nolint
return s.stmt.Exec(args)
}
// Query checks the arguments for non-utc timestamps and returns the result.
func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
if err := utcCheckArgs(args); err != nil {
return nil, errs.Wrap(err)
}
//lint:ignore SA1019 deprecated is fine. this is a wrapper.
//nolint
return s.stmt.Query(args)
}
//
// tx
//
// Tx wraps a driver.Tx with utc checks.
type Tx struct {
tx driver.Tx
}
// WrapTx wraps a driver.Tx with utc checks.
func WrapTx(tx driver.Tx) *Tx {
return &Tx{tx: tx}
}
// Unwrap returns the underlying driver.Tx.
func (t *Tx) Unwrap() driver.Tx { return t.tx }
// Commit commits the tx.
func (t *Tx) Commit() error {
return t.tx.Commit()
}
// Rollback rolls the tx back.
func (t *Tx) Rollback() error {
return t.tx.Rollback()
}
//
// helpers
//
func utcCheckArg(n int, arg interface{}) error {
var t time.Time
var ok bool
switch a := arg.(type) {
case time.Time:
t, ok = a, true
case *time.Time:
if a != nil {
t, ok = *a, true
}
}
if !ok {
return nil
} else if loc := t.Location(); loc != time.UTC {
return errs.New("invalid timezone on argument %d: %v", n, loc)
} else {
return nil
}
}
func utcCheckNamedArgs(args []driver.NamedValue) error {
for n, arg := range args {
if err := utcCheckArg(n, arg.Value); err != nil {
return err
}
}
return nil
}
func utcCheckArgs(args []driver.Value) error {
for n, arg := range args {
if err := utcCheckArg(n, arg); err != nil {
return err
}
}
return nil
}