Skip to content

Commit bfd1ba8

Browse files
authored
feat: generate only the required nullable helpers (#9)
* feat: generate only the required nullable helpers * chore: remove unnecesarry break statements * fix: remove unnecessary ToCamel call * chore: remove redundant type
1 parent 0fc94ab commit bfd1ba8

File tree

4 files changed

+89
-41
lines changed

4 files changed

+89
-41
lines changed

internal/codegen/common.go

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ package codegen
22

33
import (
44
"fmt"
5-
"github.com/tandemdude/sqlc-gen-java/internal/core"
65
"os"
76
"strings"
7+
8+
"github.com/tandemdude/sqlc-gen-java/internal/core"
89
)
910

1011
type IndentStringBuilder struct {
@@ -35,36 +36,48 @@ func (b *IndentStringBuilder) writeSqlcHeader() {
3536
b.WriteString("// sqlc-gen-java " + core.PluginVersion + "\n")
3637
}
3738

38-
func (b *IndentStringBuilder) writeQueriesBoilerplate(nonNullAnnotation, nullableAnnotation string) {
39-
methodTypes := [][]string{
40-
{"Integer", "Int"},
41-
{"Long", "Long"},
42-
{"Float", "Float"},
43-
{"Double", "Double"},
44-
{"Boolean", "Boolean"},
39+
type nullableHelper struct {
40+
ShouldOutput bool
41+
ReturnType string
42+
ArgType string
43+
}
44+
45+
func (b *IndentStringBuilder) writeNullableHelpers(nullableHelpers core.NullableHelpers, nonNullAnnotation, nullableAnnotation string) {
46+
methodTypes := []nullableHelper{
47+
{nullableHelpers.Int, "Integer", "Int"},
48+
{nullableHelpers.Long, "Long", "Long"},
49+
{nullableHelpers.Float, "Float", "Float"},
50+
{nullableHelpers.Double, "Double", "Double"},
51+
{nullableHelpers.Boolean, "Boolean", "Boolean"},
4552
}
4653

4754
for _, methodType := range methodTypes {
55+
if !methodType.ShouldOutput {
56+
continue
57+
}
58+
4859
b.WriteIndentedString(1, fmt.Sprintf(
4960
"private static %s get%s(%s rs, int col) throws SQLException {\n",
50-
core.Annotate(methodType[0], nullableAnnotation),
51-
methodType[1],
61+
core.Annotate(methodType.ReturnType, nullableAnnotation),
62+
methodType.ArgType,
5263
core.Annotate("ResultSet", nonNullAnnotation),
5364
))
5465
b.WriteIndentedString(2, fmt.Sprintf(
5566
"var colVal = rs.get%s(col); return rs.wasNull() ? null : colVal;\n",
56-
methodType[1],
67+
methodType.ArgType,
5768
))
5869
b.WriteIndentedString(1, "}\n")
5970
}
6071

61-
b.WriteIndentedString(1, fmt.Sprintf(
62-
"private static <T> %s getList(%s rs, int col, Class<T[]> as) throws SQLException {\n",
63-
core.Annotate("List<T>", nullableAnnotation),
64-
core.Annotate("ResultSet", nonNullAnnotation),
65-
))
66-
b.WriteIndentedString(2, "var colVal = rs.getArray(col); return colVal == null ? null : Arrays.asList(as.cast(colVal.getArray()));\n")
67-
b.WriteIndentedString(1, "}\n")
72+
if nullableHelpers.List {
73+
b.WriteIndentedString(1, fmt.Sprintf(
74+
"private static <T> %s getList(%s rs, int col, Class<T[]> as) throws SQLException {\n",
75+
core.Annotate("List<T>", nullableAnnotation),
76+
core.Annotate("ResultSet", nonNullAnnotation),
77+
))
78+
b.WriteIndentedString(2, "var colVal = rs.getArray(col); return colVal == null ? null : Arrays.asList(as.cast(colVal.getArray()));\n")
79+
b.WriteIndentedString(1, "}\n")
80+
}
6881
}
6982

7083
func (b *IndentStringBuilder) writeParameter(javaType core.JavaType, name, nonNullAnnotation, nullableAnnotation string) ([]string, error) {

internal/codegen/queries.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ func completeMethodBody(sb *IndentStringBuilder, q core.Query, embeddedModels co
119119
}
120120
}
121121

122-
func BuildQueriesFile(config core.Config, queryFilename string, queries []core.Query, embeddedModels core.EmbeddedModels) (string, []byte, error) {
122+
func BuildQueriesFile(config core.Config, queryFilename string, queries []core.Query, embeddedModels core.EmbeddedModels, nullableHelpers core.NullableHelpers) (string, []byte, error) {
123123
className := strcase.ToCamel(strings.TrimSuffix(queryFilename, ".sql"))
124124
className = strings.TrimSuffix(className, "Query")
125125
className = strings.TrimSuffix(className, "Queries")
@@ -161,7 +161,7 @@ func BuildQueriesFile(config core.Config, queryFilename string, queries []core.Q
161161

162162
// boilerplate methods to allow for getting null primitive values
163163
body.WriteString("\n")
164-
body.writeQueriesBoilerplate(nonNullAnnotation, nullableAnnotation)
164+
body.writeNullableHelpers(nullableHelpers, nonNullAnnotation, nullableAnnotation)
165165

166166
for _, q := range queries {
167167
body.WriteString("\n")

internal/core/models.go

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,14 @@ type QueryArg struct {
5151

5252
// TODO - enum types
5353

54-
var literalBindTypes = []string{"Integer", "Long", "Short", "String", "Boolean", "Float", "Double", "BigDecimal", "byte[]"}
55-
var typeToMethodRename = map[string]string{
56-
"Integer": "Int",
57-
"byte[]": "Bytes",
58-
}
54+
var (
55+
literalBindTypes = []string{"Integer", "Long", "Short", "String", "Boolean", "Float", "Double", "BigDecimal", "byte[]"}
56+
typeToMethodRename = map[string]string{
57+
"Integer": "Int",
58+
"byte[]": "Bytes",
59+
}
60+
)
61+
5962
var typeToJavaSqlTypeConst = map[string]string{
6063
"Integer": "INTEGER",
6164
"Long": "BIGINT",
@@ -137,5 +140,16 @@ type Query struct {
137140
Returns []QueryReturn
138141
}
139142

140-
type Queries map[string][]Query
141-
type EmbeddedModels map[string][]QueryReturn
143+
type NullableHelpers struct {
144+
Int bool
145+
Long bool
146+
Float bool
147+
Double bool
148+
Boolean bool
149+
List bool
150+
}
151+
152+
type (
153+
Queries map[string][]Query
154+
EmbeddedModels map[string][]QueryReturn
155+
)

internal/gen.go

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"encoding/json"
66
"errors"
77
"fmt"
8-
"github.com/tandemdude/sqlc-gen-java/internal/inflection"
98
"regexp"
109
"slices"
1110
"strconv"
@@ -16,6 +15,7 @@ import (
1615
"github.com/sqlc-dev/plugin-sdk-go/sdk"
1716
"github.com/tandemdude/sqlc-gen-java/internal/codegen"
1817
"github.com/tandemdude/sqlc-gen-java/internal/core"
18+
"github.com/tandemdude/sqlc-gen-java/internal/inflection"
1919
"github.com/tandemdude/sqlc-gen-java/internal/sql_types"
2020
)
2121

@@ -45,9 +45,8 @@ func fixQueryPlaceholders(engine, query string) (string, error) {
4545
return newQuery, nil
4646
}
4747

48-
func parseQueryReturn(tcf sql_types.TypeConversionFunc, col *plugin.Column) (*core.QueryReturn, error) {
49-
name := strcase.ToCamel(col.Name)
50-
javaType, err := tcf(col.Type)
48+
func parseQueryReturn(tcf sql_types.TypeConversionFunc, nullableHelpers *core.NullableHelpers, col *plugin.Column) (*core.QueryReturn, error) {
49+
strJavaType, err := tcf(col.Type)
5150
if err != nil {
5251
return nil, err
5352
}
@@ -56,14 +55,35 @@ func parseQueryReturn(tcf sql_types.TypeConversionFunc, col *plugin.Column) (*co
5655
return nil, fmt.Errorf("multidimensional arrays are not supported, store JSON instead")
5756
}
5857

58+
javaType := core.JavaType{
59+
SqlType: sdk.DataType(col.Type),
60+
Type: strJavaType,
61+
IsList: col.IsArray,
62+
IsNullable: !col.NotNull,
63+
}
64+
65+
if javaType.IsNullable {
66+
if javaType.IsList {
67+
nullableHelpers.List = true
68+
} else {
69+
switch strJavaType {
70+
case "Integer":
71+
nullableHelpers.Int = true
72+
case "Long":
73+
nullableHelpers.Long = true
74+
case "Float":
75+
nullableHelpers.Float = true
76+
case "Double":
77+
nullableHelpers.Double = true
78+
case "Boolean":
79+
nullableHelpers.Boolean = true
80+
}
81+
}
82+
}
83+
5984
return &core.QueryReturn{
60-
Name: strcase.ToLowerCamel(name),
61-
JavaType: core.JavaType{
62-
SqlType: sdk.DataType(col.Type),
63-
Type: javaType,
64-
IsList: col.IsArray,
65-
IsNullable: !col.NotNull,
66-
},
85+
Name: strcase.ToLowerCamel(col.Name),
86+
JavaType: javaType,
6787
}, nil
6888
}
6989

@@ -94,6 +114,7 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat
94114

95115
var queries core.Queries = make(map[string][]core.Query)
96116
var embeddedModels core.EmbeddedModels = make(map[string][]core.QueryReturn)
117+
nullableHelpers := core.NullableHelpers{}
97118

98119
// parse the incoming generate request into our Queries type
99120
for _, query := range req.Queries {
@@ -135,7 +156,7 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat
135156
for _, ret := range query.Columns {
136157
if ret.EmbedTable == nil {
137158
// normal types
138-
qr, err := parseQueryReturn(typeConversionFunc, ret)
159+
qr, err := parseQueryReturn(typeConversionFunc, &nullableHelpers, ret)
139160
if err != nil {
140161
return nil, errors.Join(errors.New("failed to parse query return column"), err)
141162
}
@@ -179,7 +200,7 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat
179200
if _, ok := embeddedModels[modelName]; !ok {
180201
var modelParams []core.QueryReturn
181202
for _, c := range table.Columns {
182-
qr, err := parseQueryReturn(typeConversionFunc, c)
203+
qr, err := parseQueryReturn(typeConversionFunc, &nullableHelpers, c)
183204
if err != nil {
184205
return nil, errors.Join(errors.New("failed to parse query return column"), err)
185206
}
@@ -227,7 +248,7 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat
227248
slices.SortFunc(queries[file], func(a, b core.Query) int { return strings.Compare(a.MethodName, b.MethodName) })
228249

229250
// build the queries file contents
230-
fileName, fileContents, err := codegen.BuildQueriesFile(conf, file, queries[file], embeddedModels)
251+
fileName, fileContents, err := codegen.BuildQueriesFile(conf, file, queries[file], embeddedModels, nullableHelpers)
231252
if err != nil {
232253
return nil, err
233254
}

0 commit comments

Comments
 (0)