diff --git a/internal/codegen/common.go b/internal/codegen/common.go index 37149e6..87dccc2 100644 --- a/internal/codegen/common.go +++ b/internal/codegen/common.go @@ -2,9 +2,10 @@ package codegen import ( "fmt" - "github.com/tandemdude/sqlc-gen-java/internal/core" "os" "strings" + + "github.com/tandemdude/sqlc-gen-java/internal/core" ) type IndentStringBuilder struct { @@ -35,36 +36,48 @@ func (b *IndentStringBuilder) writeSqlcHeader() { b.WriteString("// sqlc-gen-java " + core.PluginVersion + "\n") } -func (b *IndentStringBuilder) writeQueriesBoilerplate(nonNullAnnotation, nullableAnnotation string) { - methodTypes := [][]string{ - {"Integer", "Int"}, - {"Long", "Long"}, - {"Float", "Float"}, - {"Double", "Double"}, - {"Boolean", "Boolean"}, +type nullableHelper struct { + ShouldOutput bool + ReturnType string + ArgType string +} + +func (b *IndentStringBuilder) writeNullableHelpers(nullableHelpers core.NullableHelpers, nonNullAnnotation, nullableAnnotation string) { + methodTypes := []nullableHelper{ + {nullableHelpers.Int, "Integer", "Int"}, + {nullableHelpers.Long, "Long", "Long"}, + {nullableHelpers.Float, "Float", "Float"}, + {nullableHelpers.Double, "Double", "Double"}, + {nullableHelpers.Boolean, "Boolean", "Boolean"}, } for _, methodType := range methodTypes { + if !methodType.ShouldOutput { + continue + } + b.WriteIndentedString(1, fmt.Sprintf( "private static %s get%s(%s rs, int col) throws SQLException {\n", - core.Annotate(methodType[0], nullableAnnotation), - methodType[1], + core.Annotate(methodType.ReturnType, nullableAnnotation), + methodType.ArgType, core.Annotate("ResultSet", nonNullAnnotation), )) b.WriteIndentedString(2, fmt.Sprintf( "var colVal = rs.get%s(col); return rs.wasNull() ? null : colVal;\n", - methodType[1], + methodType.ArgType, )) b.WriteIndentedString(1, "}\n") } - b.WriteIndentedString(1, fmt.Sprintf( - "private static %s getList(%s rs, int col, Class as) throws SQLException {\n", - core.Annotate("List", nullableAnnotation), - core.Annotate("ResultSet", nonNullAnnotation), - )) - b.WriteIndentedString(2, "var colVal = rs.getArray(col); return colVal == null ? null : Arrays.asList(as.cast(colVal.getArray()));\n") - b.WriteIndentedString(1, "}\n") + if nullableHelpers.List { + b.WriteIndentedString(1, fmt.Sprintf( + "private static %s getList(%s rs, int col, Class as) throws SQLException {\n", + core.Annotate("List", nullableAnnotation), + core.Annotate("ResultSet", nonNullAnnotation), + )) + b.WriteIndentedString(2, "var colVal = rs.getArray(col); return colVal == null ? null : Arrays.asList(as.cast(colVal.getArray()));\n") + b.WriteIndentedString(1, "}\n") + } } func (b *IndentStringBuilder) writeParameter(javaType core.JavaType, name, nonNullAnnotation, nullableAnnotation string) ([]string, error) { diff --git a/internal/codegen/queries.go b/internal/codegen/queries.go index bf2fcd0..3fffb08 100644 --- a/internal/codegen/queries.go +++ b/internal/codegen/queries.go @@ -119,7 +119,7 @@ func completeMethodBody(sb *IndentStringBuilder, q core.Query, embeddedModels co } } -func BuildQueriesFile(config core.Config, queryFilename string, queries []core.Query, embeddedModels core.EmbeddedModels) (string, []byte, error) { +func BuildQueriesFile(config core.Config, queryFilename string, queries []core.Query, embeddedModels core.EmbeddedModels, nullableHelpers core.NullableHelpers) (string, []byte, error) { className := strcase.ToCamel(strings.TrimSuffix(queryFilename, ".sql")) className = strings.TrimSuffix(className, "Query") className = strings.TrimSuffix(className, "Queries") @@ -161,7 +161,7 @@ func BuildQueriesFile(config core.Config, queryFilename string, queries []core.Q // boilerplate methods to allow for getting null primitive values body.WriteString("\n") - body.writeQueriesBoilerplate(nonNullAnnotation, nullableAnnotation) + body.writeNullableHelpers(nullableHelpers, nonNullAnnotation, nullableAnnotation) for _, q := range queries { body.WriteString("\n") diff --git a/internal/core/models.go b/internal/core/models.go index fc08fe4..4ad8a6c 100644 --- a/internal/core/models.go +++ b/internal/core/models.go @@ -51,11 +51,14 @@ type QueryArg struct { // TODO - enum types -var literalBindTypes = []string{"Integer", "Long", "Short", "String", "Boolean", "Float", "Double", "BigDecimal", "byte[]"} -var typeToMethodRename = map[string]string{ - "Integer": "Int", - "byte[]": "Bytes", -} +var ( + literalBindTypes = []string{"Integer", "Long", "Short", "String", "Boolean", "Float", "Double", "BigDecimal", "byte[]"} + typeToMethodRename = map[string]string{ + "Integer": "Int", + "byte[]": "Bytes", + } +) + var typeToJavaSqlTypeConst = map[string]string{ "Integer": "INTEGER", "Long": "BIGINT", @@ -137,5 +140,16 @@ type Query struct { Returns []QueryReturn } -type Queries map[string][]Query -type EmbeddedModels map[string][]QueryReturn +type NullableHelpers struct { + Int bool + Long bool + Float bool + Double bool + Boolean bool + List bool +} + +type ( + Queries map[string][]Query + EmbeddedModels map[string][]QueryReturn +) diff --git a/internal/gen.go b/internal/gen.go index 3323542..3f9a8fc 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/tandemdude/sqlc-gen-java/internal/inflection" "regexp" "slices" "strconv" @@ -16,6 +15,7 @@ import ( "github.com/sqlc-dev/plugin-sdk-go/sdk" "github.com/tandemdude/sqlc-gen-java/internal/codegen" "github.com/tandemdude/sqlc-gen-java/internal/core" + "github.com/tandemdude/sqlc-gen-java/internal/inflection" "github.com/tandemdude/sqlc-gen-java/internal/sql_types" ) @@ -45,9 +45,8 @@ func fixQueryPlaceholders(engine, query string) (string, error) { return newQuery, nil } -func parseQueryReturn(tcf sql_types.TypeConversionFunc, col *plugin.Column) (*core.QueryReturn, error) { - name := strcase.ToCamel(col.Name) - javaType, err := tcf(col.Type) +func parseQueryReturn(tcf sql_types.TypeConversionFunc, nullableHelpers *core.NullableHelpers, col *plugin.Column) (*core.QueryReturn, error) { + strJavaType, err := tcf(col.Type) if err != nil { return nil, err } @@ -56,14 +55,35 @@ func parseQueryReturn(tcf sql_types.TypeConversionFunc, col *plugin.Column) (*co return nil, fmt.Errorf("multidimensional arrays are not supported, store JSON instead") } + javaType := core.JavaType{ + SqlType: sdk.DataType(col.Type), + Type: strJavaType, + IsList: col.IsArray, + IsNullable: !col.NotNull, + } + + if javaType.IsNullable { + if javaType.IsList { + nullableHelpers.List = true + } else { + switch strJavaType { + case "Integer": + nullableHelpers.Int = true + case "Long": + nullableHelpers.Long = true + case "Float": + nullableHelpers.Float = true + case "Double": + nullableHelpers.Double = true + case "Boolean": + nullableHelpers.Boolean = true + } + } + } + return &core.QueryReturn{ - Name: strcase.ToLowerCamel(name), - JavaType: core.JavaType{ - SqlType: sdk.DataType(col.Type), - Type: javaType, - IsList: col.IsArray, - IsNullable: !col.NotNull, - }, + Name: strcase.ToLowerCamel(col.Name), + JavaType: javaType, }, nil } @@ -94,6 +114,7 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat var queries core.Queries = make(map[string][]core.Query) var embeddedModels core.EmbeddedModels = make(map[string][]core.QueryReturn) + nullableHelpers := core.NullableHelpers{} // parse the incoming generate request into our Queries type for _, query := range req.Queries { @@ -135,7 +156,7 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat for _, ret := range query.Columns { if ret.EmbedTable == nil { // normal types - qr, err := parseQueryReturn(typeConversionFunc, ret) + qr, err := parseQueryReturn(typeConversionFunc, &nullableHelpers, ret) if err != nil { return nil, errors.Join(errors.New("failed to parse query return column"), err) } @@ -179,7 +200,7 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat if _, ok := embeddedModels[modelName]; !ok { var modelParams []core.QueryReturn for _, c := range table.Columns { - qr, err := parseQueryReturn(typeConversionFunc, c) + qr, err := parseQueryReturn(typeConversionFunc, &nullableHelpers, c) if err != nil { return nil, errors.Join(errors.New("failed to parse query return column"), err) } @@ -227,7 +248,7 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat slices.SortFunc(queries[file], func(a, b core.Query) int { return strings.Compare(a.MethodName, b.MethodName) }) // build the queries file contents - fileName, fileContents, err := codegen.BuildQueriesFile(conf, file, queries[file], embeddedModels) + fileName, fileContents, err := codegen.BuildQueriesFile(conf, file, queries[file], embeddedModels, nullableHelpers) if err != nil { return nil, err }