diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index e6ab7bf..29fe462 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -3,7 +3,25 @@ name: CI on: [push, pull_request] jobs: + test-go: + name: Test Go + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version-file: './go.mod' + + - name: Run tests + run: go test ./... + test-java: + needs: [test-go] + name: Test Java runs-on: ubuntu-latest diff --git a/README.md b/README.md index 84e3fea..cebda5b 100644 --- a/README.md +++ b/README.md @@ -4,11 +4,11 @@ A WASM plugin for SQLC allowing the generation of Java code. > [!NOTE] > Only the `PostgreSQL` engine is supported currently. Support for `MySQL` is planned. - + ## Configuration Values | Name | Type | Required | Description | -|--------------------------|---------|----------|------------------------------------------------------------------------------------------------------------------------------------------| +| ------------------------ | ------- | -------- | ---------------------------------------------------------------------------------------------------------------------------------------- | | `package` | string | yes | The name of the package where the generated files will be located | | `query_parameter_limit` | integer | no | not yet implemented | | `indent_char` | string | no | The character to use to indent the code. Defaults to space `" "` | @@ -19,29 +19,31 @@ A WASM plugin for SQLC allowing the generation of Java code. ## Usage `sqlc.yaml` + ```yaml -version: '2' +version: "2" plugins: -- name: java - wasm: - url: TODO - sha256: TODO + - name: java + wasm: + url: TODO + sha256: TODO sql: -- schema: src/main/resources/postgresql/schema.sql - queries: src/main/resources/postgresql/queries.sql - engine: postgresql - codegen: - - out: src/main/java/com/example/postgresql - plugin: java - options: - package: com.example.postgresql + - schema: src/main/resources/postgresql/schema.sql + queries: src/main/resources/postgresql/queries.sql + engine: postgresql + codegen: + - out: src/main/java/com/example/postgresql + plugin: java + options: + package: com.example.postgresql ``` ## Building From Source -Building the plugin is very simple, just clone the repository and run the following command within the `plugin` directory: +Building the plugin is very simple, just clone the repository and run the following command: + ```bash -GOOS=wasip1 GOARCH=wasm go build -o ../sqlc-gen-java.wasm +GOOS=wasip1 GOARCH=wasm go build -o sqlc-gen-java.wasm plugin/main.go ``` A file `sqlc-gen-java.wasm` will be created in the repository root - you can then move it to your sqlc-enabled project @@ -51,13 +53,13 @@ You should ensure that the `sha256` value in your `sqlc.yaml` is correct for thi ## Planned Features -- `sqlc.embed()` support - `MySQL` support - `SQLite` support - Improved parameter naming - First-class support for `bytes`, `blob`, `bytea` datatypes **Tentative:** + - r2dbc support - Support for PostgreSQL enum types - copyfrom support where possible [ref](https://www.baeldung.com/jdbc-batch-processing) diff --git a/go.mod b/go.mod index 7498247..4c8aa4f 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,8 @@ require github.com/sqlc-dev/plugin-sdk-go v1.23.0 require ( github.com/golang/protobuf v1.5.3 // indirect - github.com/iancoleman/strcase v0.3.0 // indirect + github.com/iancoleman/strcase v0.3.0 + github.com/jinzhu/inflection v1.0.0 golang.org/x/net v0.14.0 // indirect golang.org/x/sys v0.11.0 // indirect golang.org/x/text v0.12.0 // indirect diff --git a/go.sum b/go.sum index 9f35755..a86813d 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/iancoleman/strcase v0.3.0 h1:nTXanmYxhfFAMjZL34Ov6gkzEsSJZ5DbhxWjvSASxEI= github.com/iancoleman/strcase v0.3.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/sqlc-dev/plugin-sdk-go v1.23.0 h1:iSeJhnXPlbDXlbzUEebw/DxsGzE9rdDJArl8Hvt0RMM= github.com/sqlc-dev/plugin-sdk-go v1.23.0/go.mod h1:I1r4THOfyETD+LI2gogN2LX8wCjwUZrgy/NU4In3llA= golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14= diff --git a/internal/codegen/common.go b/internal/codegen/common.go index 63cb716..bf7d847 100644 --- a/internal/codegen/common.go +++ b/internal/codegen/common.go @@ -1,6 +1,7 @@ package codegen import ( + "fmt" "github.com/tandemdude/sqlc-gen-java/internal/core" "os" "strings" @@ -25,11 +26,36 @@ func (b *IndentStringBuilder) WriteIndentedString(level int, s string) int { return count } -func writeSqlcHeader(sb *IndentStringBuilder) { +func (b *IndentStringBuilder) writeSqlcHeader() { sqlcVersion := os.Getenv("SQLC_VERSION") - sb.WriteString("// Code generated by sqlc. DO NOT EDIT.\n") - sb.WriteString("// versions:\n") - sb.WriteString("// sqlc " + sqlcVersion + "\n") - sb.WriteString("// sqlc-gen-java " + core.PluginVersion + "\n") + b.WriteString("// Code generated by sqlc. DO NOT EDIT.\n") + b.WriteString("// versions:\n") + b.WriteString("// sqlc " + sqlcVersion + "\n") + b.WriteString("// sqlc-gen-java " + core.PluginVersion + "\n") +} + +func (b *IndentStringBuilder) writeQueriesBoilerplate(nonNullAnnotation, nullableAnnotation string) { + methodTypes := [][]string{ + {"Integer", "Int"}, + {"Long", "Long"}, + {"Float", "Float"}, + {"Double", "Double"}, + {"Boolean", "Boolean"}, + } + + for _, methodType := range methodTypes { + b.WriteIndentedString(1, fmt.Sprintf( + "private static %s%s get%s(%sResultSet rs, int col) throws SQLException {\n", + nullableAnnotation, + methodType[0], + methodType[1], + nonNullAnnotation, + )) + b.WriteIndentedString(2, fmt.Sprintf( + "var colVal = rs.get%s(col); return rs.wasNull() ? null : colVal;\n", + methodType[1], + )) + b.WriteIndentedString(1, "}\n") + } } diff --git a/internal/codegen/models.go b/internal/codegen/models.go new file mode 100644 index 0000000..2066399 --- /dev/null +++ b/internal/codegen/models.go @@ -0,0 +1,72 @@ +package codegen + +import ( + "fmt" + "slices" + "strings" + + "github.com/iancoleman/strcase" + "github.com/tandemdude/sqlc-gen-java/internal/core" +) + +func BuildModelFile(config core.Config, name string, model []core.QueryReturn) (string, []byte, error) { + imports := make([]string, 0) + + var nonNullAnnotation string + if config.NonNullAnnotation != "" { + imports = append(imports, config.NonNullAnnotation) + nonNullAnnotation = "@" + config.NonNullAnnotation[strings.LastIndex(config.NonNullAnnotation, ".")+1:] + " " + } + var nullableAnnotation string + if config.NullableAnnotation != "" { + imports = append(imports, config.NullableAnnotation) + nullableAnnotation = "@" + config.NullableAnnotation[strings.LastIndex(config.NullableAnnotation, ".")+1:] + " " + } + + header := NewIndentStringBuilder(config.IndentChar, config.CharsPerIndentLevel) + header.writeSqlcHeader() + header.WriteString("\n") + header.WriteString("package " + config.Package + ".models;\n") + header.WriteString("\n") + + body := NewIndentStringBuilder(config.IndentChar, config.CharsPerIndentLevel) + body.WriteString("\n") + body.WriteString("public record " + strcase.ToCamel(name) + "(\n") + for i, ret := range model { + imp, jt, err := core.ResolveImportAndType(ret.JavaType.Type) + if err != nil { + return "", nil, err + } + imports = append(imports, imp) + + if ret.JavaType.IsList { + imports = append(imports, "java.util.List") + jt = "List<" + jt + ">" + } + + annotation := nonNullAnnotation + if ret.JavaType.IsNullable { + annotation = nullableAnnotation + } + + body.WriteIndentedString(1, annotation+jt+" "+strcase.ToLowerCamel(ret.Name)) + if i != len(model)-1 { + body.WriteString(",\n") + } + } + body.WriteString("\n") + body.WriteString(") {}\n") + + // sort alphabetically and remove duplicate imports + slices.Sort(imports) + imports = slices.Compact(imports) + for _, imp := range imports { + if imp == "" { + continue + } + + header.WriteString("import " + imp + ";\n") + } + + return fmt.Sprintf("models/%s.java", strcase.ToCamel(name)), []byte(header.String() + body.String()), nil +} diff --git a/internal/codegen/queries.go b/internal/codegen/queries.go index db99c9f..666af8e 100644 --- a/internal/codegen/queries.go +++ b/internal/codegen/queries.go @@ -3,38 +3,71 @@ package codegen import ( "errors" "fmt" - "github.com/iancoleman/strcase" - "github.com/tandemdude/sqlc-gen-java/internal/core" "slices" "strings" + + "github.com/iancoleman/strcase" + "github.com/tandemdude/sqlc-gen-java/internal/core" ) func resultRecordName(q core.Query) string { return strcase.ToCamel(q.MethodName) + "Row" } -func createResultRecord(sb *IndentStringBuilder, indentLevel int, q core.Query) { +func createEmbeddedModel(sb *IndentStringBuilder, prefix, suffix string, identLevel, paramIdx int, r core.QueryReturn, embeddedModels core.EmbeddedModels) int { + modelName := *r.EmbeddedModel + model := embeddedModels[modelName] + + sb.WriteIndentedString(identLevel, prefix+modelName+"(\n") + for i, ret := range model { + sb.WriteIndentedString(identLevel+1, ret.ResultStmt(paramIdx)) + + if i != len(model)-1 { + sb.WriteString(",\n") + paramIdx++ + } + } + sb.WriteString("\n") + sb.WriteIndentedString(identLevel, suffix) + + return paramIdx +} + +func createResultRecord(sb *IndentStringBuilder, indentLevel int, q core.Query, embeddedModels core.EmbeddedModels) { + paramIdx := 1 + if len(q.Returns) == 1 { // set ret to the item directly instead of wrapping it in the result record + if q.Returns[0].EmbeddedModel != nil { + createEmbeddedModel(sb, "var ret = new ", ");\n", indentLevel, paramIdx, q.Returns[0], embeddedModels) + return + } + sb.WriteIndentedString(indentLevel, "var ret = "+q.Returns[0].ResultStmt(1)+";\n") return } recordName := resultRecordName(q) - sb.WriteIndentedString(indentLevel, "var ret = new "+recordName+"(\n") for i, ret := range q.Returns { - sb.WriteIndentedString(indentLevel+1, ret.ResultStmt(i+1)) + // if this return is an embedded model we need to do a lil bit extra + if ret.EmbeddedModel != nil { + paramIdx = createEmbeddedModel(sb, "new ", ")", indentLevel+1, paramIdx, ret, embeddedModels) + } else { + sb.WriteIndentedString(indentLevel+1, ret.ResultStmt(paramIdx)) + } - if i < len(q.Returns)-1 { + if i != len(q.Returns)-1 { sb.WriteString(",\n") } + + paramIdx++ } sb.WriteString("\n") sb.WriteIndentedString(indentLevel, ");\n") } -func completeMethodBody(sb *IndentStringBuilder, q core.Query) { +func completeMethodBody(sb *IndentStringBuilder, q core.Query, embeddedModels core.EmbeddedModels) { sb.WriteString("\n") switch q.Command { @@ -51,15 +84,23 @@ func completeMethodBody(sb *IndentStringBuilder, q core.Query) { sb.WriteIndentedString(2, "if (!results.next()) {\n") sb.WriteIndentedString(3, "return Optional.empty();\n") sb.WriteIndentedString(2, "}\n\n") - createResultRecord(sb, 2, q) + createResultRecord(sb, 2, q, embeddedModels) sb.WriteIndentedString(2, "if (results.next()) {\n") sb.WriteIndentedString(3, "throw new SQLException(\"expected one row in result set, but got many\");\n") sb.WriteIndentedString(2, "}\n\n") sb.WriteIndentedString(2, "return Optional.of(ret);\n") case core.Many: - sb.WriteIndentedString(2, "var retList = new ArrayList<"+resultRecordName(q)+">();\n") + jt := resultRecordName(q) + if len(q.Returns) == 1 { + _, jt, _ = core.ResolveImportAndType(q.Returns[0].JavaType.Type) + if q.Returns[0].EmbeddedModel != nil { + jt = *q.Returns[0].EmbeddedModel + } + } + + sb.WriteIndentedString(2, "var retList = new ArrayList<"+jt+">();\n") sb.WriteIndentedString(2, "while (results.next()) {\n") - createResultRecord(sb, 3, q) + createResultRecord(sb, 3, q, embeddedModels) sb.WriteIndentedString(3, "retList.add(ret);\n") sb.WriteIndentedString(2, "}\n\n") sb.WriteIndentedString(2, "return retList;\n") @@ -78,14 +119,14 @@ func completeMethodBody(sb *IndentStringBuilder, q core.Query) { } } -func BuildQueriesFile(config core.Config, queryFilename string, queries []core.Query) (string, []byte, error) { +func BuildQueriesFile(config core.Config, queryFilename string, queries []core.Query, embeddedModels core.EmbeddedModels) (string, []byte, error) { className := strcase.ToCamel(strings.TrimSuffix(queryFilename, ".sql")) className = strings.TrimSuffix(className, "Query") className = strings.TrimSuffix(className, "Queries") className += "Queries" imports := make([]string, 0) - imports = append(imports, "java.sql.Connection", "java.sql.SQLException") + imports = append(imports, "java.sql.Connection", "java.sql.SQLException", "java.sql.ResultSet", "java.sql.Types") var nonNullAnnotation string if config.NonNullAnnotation != "" { @@ -99,7 +140,7 @@ func BuildQueriesFile(config core.Config, queryFilename string, queries []core.Q } header := NewIndentStringBuilder(config.IndentChar, config.CharsPerIndentLevel) - writeSqlcHeader(header) + header.writeSqlcHeader() header.WriteString("\n") header.WriteString("package " + config.Package + ";\n") header.WriteString("\n") @@ -113,6 +154,10 @@ func BuildQueriesFile(config core.Config, queryFilename string, queries []core.Q body.WriteIndentedString(2, "this.conn = conn;\n") body.WriteIndentedString(1, "}\n") + // boilerplate methods to allow for getting null primitive values + body.WriteString("\n") + body.writeQueriesBoilerplate(nonNullAnnotation, nullableAnnotation) + for _, q := range queries { body.WriteString("\n") @@ -129,7 +174,7 @@ func BuildQueriesFile(config core.Config, queryFilename string, queries []core.Q } body.WriteIndentedString(2, "\"\"\";\n") - // write the output record class - TODO figure out if the output is an entire table and if so use a shared model + // write the output record class var returnType string if len(q.Returns) > 1 { returnType = resultRecordName(q) @@ -137,13 +182,11 @@ func BuildQueriesFile(config core.Config, queryFilename string, queries []core.Q body.WriteString("\n") body.WriteIndentedString(1, "public record "+returnType+"(\n") for i, ret := range q.Returns { - jt := ret.JavaType.Type - if strings.Contains(jt, ".") { - parts := strings.Split(jt, ".") - - imports = append(imports, jt) - jt = parts[len(parts)-1] + imp, jt, err := core.ResolveImportAndType(ret.JavaType.Type) + if err != nil { + return "", nil, err } + imports = append(imports, imp) if ret.JavaType.IsList { imports = append(imports, "java.util.List", "java.util.Arrays") @@ -151,11 +194,11 @@ func BuildQueriesFile(config core.Config, queryFilename string, queries []core.Q } annotation := nonNullAnnotation - if ret.JavaType.Nullable { + if ret.JavaType.IsNullable { annotation = nullableAnnotation } - body.WriteIndentedString(2, annotation+jt+" "+ret.Name) + body.WriteIndentedString(2, annotation+jt+" "+strcase.ToLowerCamel(ret.Name)) if i != len(q.Returns)-1 { body.WriteString(",\n") } @@ -166,13 +209,11 @@ func BuildQueriesFile(config core.Config, queryFilename string, queries []core.Q // the query only outputs a single value, we don't need to wrap it in an xxRow record class ret := q.Returns[0] - jt := ret.JavaType.Type - if strings.Contains(jt, ".") { - parts := strings.Split(jt, ".") - - imports = append(imports, jt) - jt = parts[len(parts)-1] + imp, jt, err := core.ResolveImportAndType(ret.JavaType.Type) + if err != nil { + return "", nil, err } + imports = append(imports, imp) if ret.JavaType.IsList { imports = append(imports, "java.util.List", "java.util.Arrays") @@ -210,13 +251,11 @@ func BuildQueriesFile(config core.Config, queryFilename string, queries []core.Q body.WriteString("\n") for i, arg := range q.Args { - jt := arg.JavaType.Type - if strings.Contains(jt, ".") { - parts := strings.Split(jt, ".") - - imports = append(imports, jt) - jt = parts[len(parts)-1] + imp, jt, err := core.ResolveImportAndType(arg.JavaType.Type) + if err != nil { + return "", nil, err } + imports = append(imports, imp) if arg.JavaType.IsList { imports = append(imports, "java.util.List") @@ -224,7 +263,7 @@ func BuildQueriesFile(config core.Config, queryFilename string, queries []core.Q } annotation := nonNullAnnotation - if arg.JavaType.Nullable { + if arg.JavaType.IsNullable { annotation = nullableAnnotation } @@ -241,7 +280,7 @@ func BuildQueriesFile(config core.Config, queryFilename string, queries []core.Q body.WriteString(") throws SQLException {\n") } - completeMethodBody(methodBody, q) + completeMethodBody(methodBody, q, embeddedModels) body.WriteString(methodBody.String()) body.WriteIndentedString(1, "}\n") } @@ -251,6 +290,10 @@ func BuildQueriesFile(config core.Config, queryFilename string, queries []core.Q slices.Sort(imports) imports = slices.Compact(imports) for _, imp := range imports { + if imp == "" { + continue + } + header.WriteString("import " + imp + ";\n") } diff --git a/internal/core/models.go b/internal/core/models.go index 77f35e5..f3b46ce 100644 --- a/internal/core/models.go +++ b/internal/core/models.go @@ -37,10 +37,10 @@ func QueryCommandFor(rawCommand string) (QueryCommand, error) { } type JavaType struct { - SqlType string - Type string - IsList bool - Nullable bool + SqlType string + Type string + IsList bool + IsNullable bool } type QueryArg struct { @@ -50,31 +50,52 @@ type QueryArg struct { } // TODO - enum types -var literalBindTypes = []string{"Long", "Short", "String", "Boolean", "Float", "Double", "BigDecimal"} +var literalBindTypes = []string{"Integer", "Long", "Short", "String", "Boolean", "Float", "Double", "BigDecimal"} +var bindTypeToJavaSqlTypeConst = map[string]string{ + "Integer": "INTEGER", + "Long": "BIGINT", + "Short": "SMALLINT", + "Boolean": "BOOLEAN", + "Float": "REAL", + "Double": "DOUBLE", +} func (q QueryArg) BindStmt() string { typeOnly := q.JavaType.Type[strings.LastIndex(q.JavaType.Type, ".")+1:] if q.JavaType.IsList { - if q.JavaType.Nullable { + if q.JavaType.IsNullable { return fmt.Sprintf("stmt.setArray(%d, %s == null ? null : conn.createArrayOf(\"%s\", %s.toArray()));", q.Number, q.Name, q.JavaType.SqlType, q.Name) } return fmt.Sprintf("stmt.setArray(%d, conn.createArrayOf(\"%s\", %s.toArray()));", q.Number, q.JavaType.SqlType, q.Name) } if slices.Contains(literalBindTypes, typeOnly) { - return fmt.Sprintf("stmt.set%s(%d, %s);", typeOnly, q.Number, q.Name) + javaSqlType, ok := bindTypeToJavaSqlTypeConst[typeOnly] + // annoying special case + if typeOnly == "Integer" { + typeOnly = "Int" + } + rawSet := fmt.Sprintf("stmt.set%s(%d, %s);", typeOnly, q.Number, q.Name) + + if !q.JavaType.IsNullable || !ok { + return rawSet + } + + return fmt.Sprintf("%s == null ? stmt.setNull(%d, Types.%s) : %s", q.Name, q.Number, javaSqlType, rawSet) } return fmt.Sprintf("stmt.setObject(%d, %s);", q.Number, q.Name) } type QueryReturn struct { - Name string - JavaType JavaType + Name string + JavaType JavaType + EmbeddedModel *string } func (q QueryReturn) ResultStmt(number int) string { + //_, typeOnly, _ := typeOnly := q.JavaType.Type[strings.LastIndex(q.JavaType.Type, ".")+1:] if q.JavaType.IsList { @@ -82,6 +103,14 @@ func (q QueryReturn) ResultStmt(number int) string { } if slices.Contains(literalBindTypes, typeOnly) { + // annoying special case + if typeOnly == "Integer" { + typeOnly = "Int" + } + + if q.JavaType.IsNullable { + return fmt.Sprintf("get%s(results, %d)", typeOnly, number) + } return fmt.Sprintf("results.get%s(%d)", typeOnly, number) } @@ -99,3 +128,4 @@ type Query struct { } type Queries map[string][]Query +type EmbeddedModels map[string][]QueryReturn diff --git a/internal/core/utils.go b/internal/core/utils.go new file mode 100644 index 0000000..d06d5f5 --- /dev/null +++ b/internal/core/utils.go @@ -0,0 +1,34 @@ +package core + +import ( + "fmt" + "slices" + "strings" + "unicode" + "unicode/utf8" +) + +// ResolveImportAndType extracts the import required, and type representation of the given java type. +func ResolveImportAndType(typ string) (string, string, error) { + if !strings.Contains(typ, ".") { + return "", typ, nil + } + + parts := strings.Split(typ, ".") + capitalIdx := slices.IndexFunc(parts, func(s string) bool { + r, _ := utf8.DecodeRuneInString(s) + return unicode.IsUpper(r) + }) + + if capitalIdx == -1 { + // fatal error, this should never happen + return "", "", fmt.Errorf("failed resolving type and import for %s", typ) + } + + if capitalIdx == 0 { + // special case - nested class in same package, no import required + return "", strings.Join(parts, "."), nil + } + // build the import and the type name from the resolved outer class name + return strings.Join(parts[:capitalIdx+1], "."), strings.Join(parts[capitalIdx:], "."), nil +} diff --git a/internal/core/utils_test.go b/internal/core/utils_test.go new file mode 100644 index 0000000..cc1a6f7 --- /dev/null +++ b/internal/core/utils_test.go @@ -0,0 +1,34 @@ +package core + +import "testing" + +func TestResolveImportAndType(t *testing.T) { + cases := []string{ + "foo.bar.baz.Bork.Qux", + "com.example.MyClass.InnerClass", + "java.util.List", + "SingleClass", + "Nested.Class", + } + expected := [][]string{ + {"foo.bar.baz.Bork", "Bork.Qux"}, + {"com.example.MyClass", "MyClass.InnerClass"}, + {"java.util.List", "List"}, + {"", "SingleClass"}, + {"", "Nested.Class"}, + } + + for i, tc := range cases { + imp, typ, err := ResolveImportAndType(tc) + if err != nil { + t.Fatal(err) + } + + if imp != expected[i][0] { + t.Errorf("case %d: expected '%s', got '%s'", i, expected[i][0], imp) + } + if typ != expected[i][1] { + t.Errorf("case %d: expected '%s', got '%s'", i, expected[i][1], typ) + } + } +} diff --git a/internal/gen.go b/internal/gen.go index 3e61e80..706e7bd 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -3,17 +3,20 @@ package internal import ( "context" "encoding/json" + "errors" "fmt" + "github.com/tandemdude/sqlc-gen-java/internal/inflection" + "regexp" + "slices" + "strconv" + "strings" + "github.com/iancoleman/strcase" "github.com/sqlc-dev/plugin-sdk-go/plugin" "github.com/sqlc-dev/plugin-sdk-go/sdk" "github.com/tandemdude/sqlc-gen-java/internal/codegen" "github.com/tandemdude/sqlc-gen-java/internal/core" "github.com/tandemdude/sqlc-gen-java/internal/sql_types" - "regexp" - "slices" - "strconv" - "strings" ) var ( @@ -42,7 +45,27 @@ func fixQueryPlaceholders(engine, query string) (string, error) { return newQuery, nil } -// TODO - consider sqlc.embed support +func parseQueryReturn(tcf sql_types.TypeConversionFunc, col *plugin.Column) (*core.QueryReturn, error) { + name := strcase.ToCamel(col.Name) + javaType, err := tcf(col.Type) + if err != nil { + return nil, err + } + + if col.ArrayDims > 1 { + return nil, fmt.Errorf("multidimensional arrays are not supported, store JSON instead") + } + + return &core.QueryReturn{ + Name: name, + JavaType: core.JavaType{ + SqlType: sdk.DataType(col.Type), + Type: javaType, + IsList: col.IsArray, + IsNullable: !col.NotNull, + }, + }, nil +} func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.GenerateResponse, error) { conf := core.Config{ @@ -65,8 +88,10 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat return nil, fmt.Errorf("engine %q is not supported", req.Settings.Engine) } - // parse the incoming generate request into our Queries type var queries core.Queries = make(map[string][]core.Query) + var embeddedModels core.EmbeddedModels = make(map[string][]core.QueryReturn) + + // parse the incoming generate request into our Queries type for _, query := range req.Queries { if _, ok := queries[query.Filename]; !ok { queries[query.Filename] = make([]core.Query, 0) @@ -77,7 +102,7 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat return nil, err } - // TODO - check for array types? enum types? other specialness? + // TODO - enum types? other specialness? args := make([]core.QueryArg, 0) for _, arg := range query.Params { javaType, err := typeConversionFunc(arg.Column.Type) @@ -93,34 +118,80 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat Number: int(arg.Number), Name: arg.Column.Name, JavaType: core.JavaType{ - SqlType: sdk.DataType(arg.Column.Type), - Type: javaType, - IsList: arg.Column.ArrayDims > 0, // TODO check this will always be present - Nullable: !arg.Column.NotNull, + SqlType: sdk.DataType(arg.Column.Type), + Type: javaType, + IsList: arg.Column.IsArray, // TODO check this will always be present + IsNullable: !arg.Column.NotNull, }, }) } - // TODO - check for array types? enum types? other specialness? - returns := make([]core.QueryReturn, 0) + // TODO - enum types? other specialness? + var returns []core.QueryReturn for _, ret := range query.Columns { - javaType, err := typeConversionFunc(ret.Type) - if err != nil { - return nil, err + if ret.EmbedTable == nil { + // normal types + qr, err := parseQueryReturn(typeConversionFunc, ret) + if err != nil { + return nil, errors.Join(errors.New("failed to parse query return column"), err) + } + + returns = append(returns, *qr) + continue } - if ret.ArrayDims > 1 { - return nil, fmt.Errorf("multidimensional arrays are not supported, store JSON instead") + // handle embedded types + var table *plugin.Table + + // find the catalog entry for the embedded table + schema := req.Catalog.DefaultSchema + if ret.EmbedTable.Schema != "" { + schema = ret.EmbedTable.Schema + } + + for _, s := range req.Catalog.Schemas { + if s.Name != schema { + continue + } + + for _, t := range s.Tables { + if ret.EmbedTable.Name == t.Rel.Name { + table = t + break + } + } + } + if table == nil { + return nil, fmt.Errorf("unknown embedded table %s.%s", schema, ret.EmbedTable) + } + + // TODO - fix type-writer to only exclude items that aren't part of the package name + modelName := strcase.ToCamel(inflection.Singular(table.Rel.Name)) + // check if we already have an entry for this model + if _, ok := embeddedModels[modelName]; !ok { + var modelParams []core.QueryReturn + for _, c := range table.Columns { + qr, err := parseQueryReturn(typeConversionFunc, c) + if err != nil { + return nil, errors.Join(errors.New("failed to parse query return column"), err) + } + + modelParams = append(modelParams, *qr) + } + + embeddedModels[modelName] = modelParams } returns = append(returns, core.QueryReturn{ - Name: ret.Name, + Name: strcase.ToLowerCamel(modelName), JavaType: core.JavaType{ - SqlType: sdk.DataType(ret.Type), - Type: javaType, - IsList: ret.ArrayDims > 0, // TODO check this will always be present - Nullable: !ret.NotNull, + SqlType: "", + // we don't need to specify package here - models file will be generated in the same location as the queries file + Type: conf.Package + ".models." + modelName, + IsList: false, // TODO - check: this *should* be impossible + IsNullable: false, // TODO - check: empty record should be output instead }, + EmbeddedModel: &modelName, }) } @@ -148,7 +219,7 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat slices.SortFunc(queries[file], func(a, b core.Query) int { return strings.Compare(a.MethodName, b.MethodName) }) // build the queries file contents - fileName, fileContents, err := codegen.BuildQueriesFile(conf, file, queries[file]) + fileName, fileContents, err := codegen.BuildQueriesFile(conf, file, queries[file], embeddedModels) if err != nil { return nil, err } @@ -158,7 +229,16 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat }) } - // TODO - figure out common output models so we don't duplicate the same model in code 100 times + for modelName, model := range embeddedModels { + fileName, fileContents, err := codegen.BuildModelFile(conf, modelName, model) + if err != nil { + return nil, err + } + outputFiles = append(outputFiles, &plugin.File{ + Name: fileName, + Contents: fileContents, + }) + } return &plugin.GenerateResponse{Files: outputFiles}, nil } diff --git a/internal/inflection/singular.go b/internal/inflection/singular.go new file mode 100644 index 0000000..1d95274 --- /dev/null +++ b/internal/inflection/singular.go @@ -0,0 +1,24 @@ +package inflection + +import ( + "github.com/jinzhu/inflection" + "strings" +) + +func Singular(s string) string { + // Manual fix for incorrect handling of "campus" + // + // https://github.com/kyleconroy/sqlc/issues/430 + // https://github.com/jinzhu/inflection/issues/13 + if strings.ToLower(s) == "campus" { + return s + } + // Manual fix for incorrect handling of "meta" + // + // https://github.com/kyleconroy/sqlc/issues/1217 + // https://github.com/jinzhu/inflection/issues/21 + if strings.ToLower(s) == "meta" { + return s + } + return inflection.Singular(s) +} diff --git a/tests/.gitignore b/tests/.gitignore index 1841315..e0f7127 100644 --- a/tests/.gitignore +++ b/tests/.gitignore @@ -38,9 +38,9 @@ build/ .DS_Store ### Generated Files ### -src/main/java/io/github/tandemdude/sgj/mysql/*.java +src/main/java/io/github/tandemdude/sgj/mysql/**/*.java !src/main/java/io/github/tandemdude/sgj/mysql/package-info.java -src/main/java/io/github/tandemdude/sgj/postgres/*.java +src/main/java/io/github/tandemdude/sgj/postgres/**/*.java !src/main/java/io/github/tandemdude/sgj/postgres/package-info.java -src/main/java/io/github/tandemdude/sgj/sqlite/*.java +src/main/java/io/github/tandemdude/sgj/sqlite/**/*.java !src/main/java/io/github/tandemdude/sgj/sqlite/package-info.java diff --git a/tests/sqlc.yaml b/tests/sqlc.yaml index 1607c6d..d51e657 100644 --- a/tests/sqlc.yaml +++ b/tests/sqlc.yaml @@ -1,9 +1,9 @@ -version: '2' +version: "2" plugins: - name: java wasm: url: file://sqlc-gen-java.wasm - sha256: b81971e01687b14e8732973c8424c72b81b354b95bf526b58b46ab9213c1364e + sha256: ec7a6b66a3c514f8b638d0f274295ca067722fa47c566f48c2d360687af6d1c6 sql: - schema: src/main/resources/postgres/schema.sql queries: src/main/resources/postgres/queries.sql diff --git a/tests/src/main/resources/postgres/queries.sql b/tests/src/main/resources/postgres/queries.sql index 53808bb..eee3b88 100644 --- a/tests/src/main/resources/postgres/queries.sql +++ b/tests/src/main/resources/postgres/queries.sql @@ -1,5 +1,6 @@ -- name: CreateUser :exec -INSERT INTO users(user_id, username, email) VALUES ($1, $2, $3); +INSERT INTO users(user_id, username, email) +VALUES ($1, $2, $3); -- name: GetUser :one SELECT * FROM users WHERE user_id = $1; @@ -10,6 +11,26 @@ SELECT * FROM users WHERE user_id IS NOT NULL; -- name: ListUsers :many SELECT * FROM users; +-- name: CreateToken :one +INSERT INTO tokens(user_id, token, expiry) +VALUES ($1, $2, $3) +RETURNING token_id; + +-- name: GetUserAndToken :one +SELECT sqlc.embed(users), sqlc.embed(tokens) +FROM users +JOIN tokens ON tokens.user_id = users.user_id +WHERE users.user_id = $1; + +-- name: GetEmbeddedUser :one +SELECT sqlc.embed(users) +FROM users +WHERE users.user_id = $1; + +-- name: ListEmbeddedUsers :many +SELECT sqlc.embed(users) +FROM users; + -- name: CreateMessage :one INSERT INTO messages(chat_id, user_id, content, attachments) VALUES ($1, $2, $3, $4) diff --git a/tests/src/test/java/io/github/tandemdude/sgj/postgres/TestQueries.java b/tests/src/test/java/io/github/tandemdude/sgj/postgres/TestQueries.java index cd1a50a..3093994 100644 --- a/tests/src/test/java/io/github/tandemdude/sgj/postgres/TestQueries.java +++ b/tests/src/test/java/io/github/tandemdude/sgj/postgres/TestQueries.java @@ -9,6 +9,7 @@ import java.sql.Connection; import java.sql.DriverManager; import java.sql.SQLException; +import java.time.LocalDateTime; import java.util.List; import java.util.UUID; @@ -48,7 +49,7 @@ public void getUserReturnsPopulatedOptionalRecordFound() throws Exception { var found = q.getUser(uid); assertThat(found).isPresent(); - assertThat(found.get().user_id()).isEqualTo(uid); + assertThat(found.get().userId()).isEqualTo(uid); assertThat(found.get().username()).isEqualTo("foo"); assertThat(found.get().email()).isEqualTo("bar"); } @@ -108,4 +109,25 @@ public void createMessageProcessesInputListCorrectly() throws Exception { assertThat(found.get().attachments()).containsExactly("bar", "baz", "bork"); } } + + @Test + @DisplayName("GetUserAndToken returns embeded objects") + public void getUserAndTokenReturnsEmbededObjects() throws Exception { + try (var conn = getConn()) { + // given + var q = new Queries(conn); + + var userUid = UUID.randomUUID(); + q.createUser(userUid, "foo", "bar"); + q.createToken(userUid, "token", LocalDateTime.now()); + + // when + var userAndToken = q.getUserAndToken(userUid); + + // then + assertThat(userAndToken).isPresent(); + assertThat(userAndToken.get().user().username()).isEqualTo("foo"); + assertThat(userAndToken.get().token().userId()).isEqualTo(userUid); + } + } }