diff --git a/.gitignore b/.gitignore index ecfc354..e74c392 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,5 @@ cpu.out *.sublime-project *.sublime-workspace + +.idea diff --git a/db.go b/db.go index fca8ff4..a14664e 100644 --- a/db.go +++ b/db.go @@ -1,6 +1,7 @@ package jet import ( + "context" "database/sql" ) @@ -18,22 +19,29 @@ type Db struct { // Defaults to SnakeCaseConverter. ColumnConverter ColumnConverter - driver string - source string - lru *lru + driver string + source string + lru *lru + skipPreparedStmts bool } // Open opens a new database connection. -func Open(driverName, dataSourceName string) (*Db, error) { - db, err := sql.Open(driverName, dataSourceName) +func Open(driverName, dataSourceName string, usePreparedStmts bool, preparedStmtCacheSize int) (*Db, error) { + return OpenFunc(driverName, dataSourceName, sql.Open, usePreparedStmts, preparedStmtCacheSize) +} + +// OpenFunc opens a new database connection by using the passed `fn`. +func OpenFunc(driverName, dataSourceName string, fn func(string, string) (*sql.DB, error), usePreparedStmts bool, preparedStmtCacheSize int) (*Db, error) { + db, err := fn(driverName, dataSourceName) if err != nil { return nil, err } j := &Db{ - ColumnConverter: SnakeCaseConverter, // default - driver: driverName, - source: dataSourceName, - lru: newLru(), + ColumnConverter: SnakeCaseConverter, // default + driver: driverName, + source: dataSourceName, + lru: newLru(preparedStmtCacheSize), + skipPreparedStmts: usePreparedStmts, } j.DB = db @@ -64,5 +72,14 @@ func (db *Db) Begin() (*Tx, error) { // Query creates a prepared query that can be run with Rows or Run. func (db *Db) Query(query string, args ...interface{}) Runnable { - return newQuery(db, db, query, args...) + return db.QueryContext(context.Background(), query, args...) +} + +// QueryContext creates a prepared query that can be run with Rows or Run. +func (db *Db) QueryContext(ctx context.Context, query string, args ...interface{}) Runnable { + return newQuery(ctx, db.skipPreparedStmts, db, db, query, args...) +} + +func (db *Db) CacheSize() int { + return db.lru.size() } diff --git a/lru.go b/lru.go index 759a3d9..541a4b0 100644 --- a/lru.go +++ b/lru.go @@ -19,9 +19,9 @@ type lruItem struct { stmt *sql.Stmt } -func newLru() *lru { +func newLru(maxItems int) *lru { return &lru{ - maxItems: 500, + maxItems: maxItems, keys: make(map[string]*list.Element), list: list.New(), } @@ -79,6 +79,10 @@ func (c *lru) clean() { } } +func (c *lru) size() int { + return c.list.Len() +} + // makeKey hashes the key to save some bytes func makeKey(k string) string { buffer := sha1.New() diff --git a/mapper.go b/mapper.go index 96f2ec5..cfd065d 100644 --- a/mapper.go +++ b/mapper.go @@ -20,9 +20,34 @@ func (m *mapper) unpack(keys []string, values []interface{}, out interface{}) er return m.unpackValue(keys, values, val) } +func isNil(val interface{}) bool { + if val == nil { + return true + } + if reflect.ValueOf(val).IsZero() { + return true + } + if reflect.ValueOf(val).Kind() == reflect.Ptr { + if reflect.ValueOf(val).Elem().Kind() == reflect.Struct || reflect.ValueOf(val).Elem().Kind() == reflect.Interface { + return reflect.ValueOf(val).Elem().IsNil() + } + } + + return false +} + func (m *mapper) unpackValue(keys []string, values []interface{}, out reflect.Value) error { switch out.Interface().(type) { case ComplexValue: + if isNil(values[0]) { + if out.IsZero() { + return nil + } + if out.CanSet() { + out.Set(reflect.Zero(out.Type())) + return nil + } + } if out.IsNil() { out.Set(reflect.New(out.Type().Elem())) } @@ -82,6 +107,18 @@ func (m *mapper) unpackStruct(keys []string, values []interface{}, out reflect.V convKey = m.conv.ColumnToFieldName(k) } field := out.FieldByName(convKey) + + // If the field is not found it can mean that we don't want it or that + // we have special case like UserID, UUID, userUUID + // So fix the name and try again + if !field.IsValid() { + convKey = strings.Replace(convKey, "Uuid", "UUID", -1) + convKey = strings.Replace(convKey, "Id", "ID", -1) + convKey = strings.Replace(convKey, "Ip", "IP", -1) + convKey = strings.Replace(convKey, "Url", "URL", -1) + field = out.FieldByName(convKey) + } + if field.IsValid() { m.unpackValue(nil, values[i:i+1], field) } diff --git a/mapper_test.go b/mapper_test.go index 8ea56ff..a3c0949 100644 --- a/mapper_test.go +++ b/mapper_test.go @@ -73,8 +73,13 @@ func (c *custom) Decode(v interface{}) error { } s, ok := v.(string) if ok { - c.a = string(s[0]) - c.b = string(s[1]) + if len(s) > 1 { + c.a = string(s[0]) + c.b = string(s[1]) + } else { + c.a = "" + c.b = "" + } } return nil } @@ -181,6 +186,170 @@ func TestUnpackStruct(t *testing.T) { } } +func TestUnpackStructExistingValueToEmpty(t *testing.T) { + keys := []string{"m"} + vals := []interface{}{ + "", + } + mppr := &mapper{ + conv: SnakeCaseConverter, + } + + var v struct { + M plainCustom + } + v.M = "abc" + + err := mppr.unpack(keys, vals, &v) + if err != nil { + t.Fatal(err) + } + if v.M != "" { + t.Fatal(v.M) + } +} + +func TestUnpackStructEmptyToEmpty(t *testing.T) { + keys := []string{"m"} + vals := []interface{}{ + "", + } + mppr := &mapper{ + conv: SnakeCaseConverter, + } + + var v struct { + M plainCustom + } + + err := mppr.unpack(keys, vals, &v) + if err != nil { + t.Fatal(err) + } + if v.M != "" { + t.Fatal(v.M) + } +} + +func TestUnpackStructExistingValueToNil(t *testing.T) { + keys := []string{"j"} + vals := []interface{}{ + nil, + } + mppr := &mapper{ + conv: SnakeCaseConverter, + } + + var v struct { + J *custom + } + v.J = &custom{a: "a", b: "b"} + + err := mppr.unpack(keys, vals, &v) + if err != nil { + t.Fatal(err) + } + if v.J != nil { + t.Fatal(v.J) + } +} + +func TestUnpackStructExistingValueNonPtrToEmpty(t *testing.T) { + keys := []string{"j"} + vals := []interface{}{ + "", + } + mppr := &mapper{ + conv: SnakeCaseConverter, + } + + var v struct { + J custom + } + v.J = custom{a: "a", b: "b"} + + err := mppr.unpack(keys, vals, &v) + if err != nil { + t.Fatal(err) + } + if v.J.a != "" || v.J.b != "" { + t.Fatal(v.J) + } +} + +func TestUnpackStructComplexExistingValueToEmpty(t *testing.T) { + keys := []string{"j"} + vals := []interface{}{ + "", + } + mppr := &mapper{ + conv: SnakeCaseConverter, + } + + var v struct { + J *custom + } + v.J = &custom{a: "a", b: "b"} + + err := mppr.unpack(keys, vals, &v) + if err != nil { + t.Fatal(err) + } + if v.J != nil { + t.Fatal(v.J) + } +} + + +func TestUnpackStructNilLikeDBQuery(t *testing.T) { + keys := []string{"j"} + vals := make([]interface{}, 0, len(keys)) + for i := 0; i < cap(vals); i++ { + vals = append(vals, new(interface{})) + } + mppr := &mapper{ + conv: SnakeCaseConverter, + } + + var v struct { + J *custom + } + v.J = &custom{ + a: "a", b: "b", + } + + err := mppr.unpack(keys, vals, &v) + if err != nil { + t.Fatal(err) + } + if v.J != nil { + t.Fatal(v.J) + } +} + +func TestUnpackStructNilComplexToNil(t *testing.T) { + keys := []string{"j"} + vals := []interface{}{ + nil, + } + mppr := &mapper{ + conv: SnakeCaseConverter, + } + + var v struct { + J *custom + } + + err := mppr.unpack(keys, vals, &v) + if err != nil { + t.Fatal(err) + } + if v.J != nil { + t.Fatal(v.J) + } +} + + func TestUnpackMap(t *testing.T) { keys := []string{"ab_c", "c_d", "e"} vals := []interface{}{int64(9), "hello", "unsettable"} diff --git a/query.go b/query.go index 3ec7952..1cd0907 100644 --- a/query.go +++ b/query.go @@ -1,27 +1,32 @@ package jet import ( + "context" "database/sql" "sync" ) type jetQuery struct { - m sync.Mutex - db *Db - qo queryObject - id string - query string - args []interface{} + m sync.Mutex + db *Db + qo queryObject + id string + query string + args []interface{} + ctx context.Context + skipPreparedStmts bool } // newQuery initiates a new query for the provided query object (either *sql.Tx or *sql.DB) -func newQuery(qo queryObject, db *Db, query string, args ...interface{}) *jetQuery { +func newQuery(ctx context.Context, skipPreparedStmts bool, qo queryObject, db *Db, query string, args ...interface{}) *jetQuery { return &jetQuery{ - qo: qo, - db: db, - id: newQueryId(), - query: query, - args: args, + qo: qo, + db: db, + id: newQueryId(), + query: query, + args: args, + ctx: ctx, + skipPreparedStmts: skipPreparedStmts, } } @@ -33,12 +38,19 @@ func (q *jetQuery) Rows(v interface{}) (err error) { q.m.Lock() defer q.m.Unlock() + if q.ctx == nil { + q.ctx = context.Background() + } + // disable lru in transactions useLru := true switch q.qo.(type) { case *sql.Tx: useLru = false } + if q.skipPreparedStmts { + useLru = false + } query, args := substituteMapAndArrayMarks(q.query, q.args...) @@ -67,28 +79,47 @@ func (q *jetQuery) Rows(v interface{}) (err error) { } // prepare statement - stmt, ok := q.db.lru.get(query) - if !useLru || !ok { - stmt, err = q.qo.Prepare(query) + var rows *sql.Rows + if q.skipPreparedStmts { + conn, err := q.db.DB.Conn(q.ctx) if err != nil { return err } - if useLru { - q.db.lru.put(query, stmt) + defer conn.Close() + + if v == nil { + _, err := conn.ExecContext(q.ctx, query, args...) + return err } - } - // If no rows need to be unpacked use Exec - if v == nil { - _, err := stmt.Exec(args...) - return err - } + rows, err = conn.QueryContext(q.ctx, query, args...) + } else { + stmt, ok := q.db.lru.get(query) + if !useLru || !ok { + stmt, err = q.qo.Prepare(query) + if err != nil { + return err + } + if useLru { + q.db.lru.put(query, stmt) + } else { + defer stmt.Close() + } + } + // If no rows need to be unpacked use Exec + if v == nil { + _, err := stmt.ExecContext(q.ctx, args...) + return err + } + + // run query + rows, err = stmt.QueryContext(q.ctx, args...) + if err != nil { + return err + } - // run query - rows, err := stmt.Query(args...) - if err != nil { - return err } + defer rows.Close() cols, err := rows.Columns() diff --git a/tx.go b/tx.go index 7124808..f47a908 100644 --- a/tx.go +++ b/tx.go @@ -1,7 +1,9 @@ package jet import ( + "context" "database/sql" + "errors" ) // Tx represents a transaction instance. @@ -14,11 +16,24 @@ type Tx struct { // Query creates a prepared query that can be run with Rows or Run. func (tx *Tx) Query(query string, args ...interface{}) Runnable { - q := newQuery(tx.tx, tx.db, query, args...) + return tx.QueryContext(context.Background(), query, args...) +} + +// QueryContext creates a prepared query that can be run with Rows or Run. +func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) Runnable { + q := newQuery(ctx, tx.tx, tx.db, query, args...) q.id = tx.qid return q } +// Exec calls Exec on the underlying sql.Tx. +func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) { + if tx == nil || tx.tx == nil { + return nil, errors.New("jet: Exec called on nil transaction") + } + return tx.tx.Exec(query, args...) +} + // Commit commits the transaction func (tx *Tx) Commit() error { if tx.db.LogFunc != nil {