diff --git a/internal/codegen/queries.go b/internal/codegen/queries.go index f3d68f8..52196d6 100644 --- a/internal/codegen/queries.go +++ b/internal/codegen/queries.go @@ -14,13 +14,21 @@ func resultRecordName(q core.Query) string { return strcase.ToCamel(q.MethodName) + "Row" } -func createEmbeddedModel(sb *IndentStringBuilder, prefix, suffix string, identLevel, paramIdx int, r core.QueryReturn, embeddedModels core.EmbeddedModels) int { +func createEmbeddedModel(sb *IndentStringBuilder, prefix, suffix string, identLevel, paramIdx int, r core.QueryReturn, embeddedModels core.EmbeddedModels) (int, []string, error) { + imports := make([]string, 0) + modelName := *r.EmbeddedModel model := embeddedModels[modelName] sb.WriteIndentedString(identLevel, prefix+modelName+"(\n") for i, ret := range model { - sb.WriteIndentedString(identLevel+1, ret.ResultStmt(paramIdx)) + stm, imp, err := ret.ResultStmt(paramIdx) + if err != nil { + return 0, nil, nil + } + + imports = append(imports, imp) + sb.WriteIndentedString(identLevel+1, stm) if i != len(model)-1 { sb.WriteString(",\n") @@ -30,31 +38,54 @@ func createEmbeddedModel(sb *IndentStringBuilder, prefix, suffix string, identLe sb.WriteString("\n") sb.WriteIndentedString(identLevel, suffix) - return paramIdx + return paramIdx, imports, nil } -func createResultRecord(sb *IndentStringBuilder, indentLevel int, q core.Query, embeddedModels core.EmbeddedModels) { +func createResultRecord(sb *IndentStringBuilder, indentLevel int, q core.Query, embeddedModels core.EmbeddedModels) ([]string, error) { paramIdx := 1 if len(q.Returns) == 1 { // set ret to the item directly instead of wrapping it in the result record if q.Returns[0].EmbeddedModel != nil { - createEmbeddedModel(sb, "var ret = new ", ");\n", indentLevel, paramIdx, q.Returns[0], embeddedModels) - return + _, imps, err := createEmbeddedModel(sb, "var ret = new ", ");\n", indentLevel, paramIdx, q.Returns[0], embeddedModels) + if err != nil { + return nil, err + } + return imps, err } - sb.WriteIndentedString(indentLevel, "var ret = "+q.Returns[0].ResultStmt(1)+";\n") - return + stm, imp, err := q.Returns[0].ResultStmt(1) + if err != nil { + return nil, err + } + + imports := []string{imp} + sb.WriteIndentedString(indentLevel, "var ret = "+stm+";\n") + return imports, err } recordName := resultRecordName(q) sb.WriteIndentedString(indentLevel, "var ret = new "+recordName+"(\n") + imports := make([]string, 0) + for i, ret := range q.Returns { // if this return is an embedded model we need to do a lil bit extra if ret.EmbeddedModel != nil { - paramIdx = createEmbeddedModel(sb, "new ", ")", indentLevel+1, paramIdx, ret, embeddedModels) + newIdx, imps, err := createEmbeddedModel(sb, "new ", ")", indentLevel+1, paramIdx, ret, embeddedModels) + if err != nil { + return nil, err + } + + imports = append(imports, imps...) + paramIdx = newIdx } else { - sb.WriteIndentedString(indentLevel+1, ret.ResultStmt(paramIdx)) + stm, imp, err := ret.ResultStmt(paramIdx) + if err != nil { + return nil, err + } + + imports = append(imports, imp) + sb.WriteIndentedString(indentLevel+1, stm) } if i != len(q.Returns)-1 { @@ -65,9 +96,10 @@ func createResultRecord(sb *IndentStringBuilder, indentLevel int, q core.Query, } sb.WriteString("\n") sb.WriteIndentedString(indentLevel, ");\n") + return imports, nil } -func completeMethodBody(sb *IndentStringBuilder, q core.Query, embeddedModels core.EmbeddedModels) { +func completeMethodBody(sb *IndentStringBuilder, q core.Query, embeddedModels core.EmbeddedModels) ([]string, error) { sb.WriteString("\n") switch q.Command { @@ -84,11 +116,15 @@ func completeMethodBody(sb *IndentStringBuilder, q core.Query, embeddedModels co sb.WriteIndentedString(2, "if (!results.next()) {\n") sb.WriteIndentedString(3, "return Optional.empty();\n") sb.WriteIndentedString(2, "}\n\n") - createResultRecord(sb, 2, q, embeddedModels) + imports, err := createResultRecord(sb, 2, q, embeddedModels) + if err != nil { + return nil, err + } sb.WriteIndentedString(2, "if (results.next()) {\n") sb.WriteIndentedString(3, "throw new SQLException(\"expected one row in result set, but got many\");\n") sb.WriteIndentedString(2, "}\n\n") sb.WriteIndentedString(2, "return Optional.of(ret);\n") + return imports, nil case core.Many: jt := resultRecordName(q) if len(q.Returns) == 1 { @@ -100,10 +136,14 @@ func completeMethodBody(sb *IndentStringBuilder, q core.Query, embeddedModels co sb.WriteIndentedString(2, "var retList = new ArrayList<"+jt+">();\n") sb.WriteIndentedString(2, "while (results.next()) {\n") - createResultRecord(sb, 3, q, embeddedModels) + imports, err := createResultRecord(sb, 3, q, embeddedModels) + if err != nil { + return nil, err + } sb.WriteIndentedString(3, "retList.add(ret);\n") sb.WriteIndentedString(2, "}\n\n") sb.WriteIndentedString(2, "return retList;\n") + return imports, nil case core.Exec: break case core.ExecRows: @@ -117,6 +157,8 @@ func completeMethodBody(sb *IndentStringBuilder, q core.Query, embeddedModels co default: sb.WriteIndentedString(2, "// TODO\n") } + + return nil, nil } func BuildQueriesFile(engine string, config core.Config, queryFilename string, queries []core.Query, embeddedModels core.EmbeddedModels, nullableHelpers core.NullableHelpers) (string, []byte, error) { @@ -272,7 +314,12 @@ func BuildQueriesFile(engine string, config core.Config, queryFilename string, q body.WriteString(") throws SQLException {\n") } - completeMethodBody(methodBody, q, embeddedModels) + imps, err := completeMethodBody(methodBody, q, embeddedModels) + if err != nil { + return "", nil, err + } + imports = append(imports, imps...) + body.WriteString(methodBody.String()) body.WriteIndentedString(1, "}\n") } diff --git a/internal/core/models.go b/internal/core/models.go index c07bdc9..9987789 100644 --- a/internal/core/models.go +++ b/internal/core/models.go @@ -119,14 +119,18 @@ type QueryReturn struct { EmbeddedModel *string } -func (q QueryReturn) ResultStmt(number int) string { +func (q QueryReturn) ResultStmt(number int) (string, string, error) { + imp, _, err := ResolveImportAndType(q.JavaType.Type) + if err != nil { + return "", "", err + } typeOnly := q.JavaType.Type[strings.LastIndex(q.JavaType.Type, ".")+1:] if q.JavaType.IsList { if q.JavaType.IsNullable { - return fmt.Sprintf("getList(results, %d, %s[].class)", number, typeOnly) + return fmt.Sprintf("getList(results, %d, %s[].class)", number, typeOnly), imp, nil } - return fmt.Sprintf("Arrays.asList(%s[].class.cast(results.getArray(%d).getArray()))", typeOnly, number) + return fmt.Sprintf("Arrays.asList(%s[].class.cast(results.getArray(%d).getArray()))", typeOnly, number), imp, nil } if slices.Contains(literalBindTypes, typeOnly) { @@ -137,19 +141,19 @@ func (q QueryReturn) ResultStmt(number int) string { } if q.JavaType.IsNullable && ok { - return fmt.Sprintf("get%s(results, %d)", typeOnly, number) + return fmt.Sprintf("get%s(results, %d)", typeOnly, number), imp, nil } - return fmt.Sprintf("results.get%s(%d)", typeOnly, number) + return fmt.Sprintf("results.get%s(%d)", typeOnly, number), imp, nil } if q.JavaType.IsEnum { if q.JavaType.IsNullable { - return fmt.Sprintf("Optional.ofNullable(results.getString(%d)).map(%s::fromValue).orElse(null)", number, typeOnly) + return fmt.Sprintf("Optional.ofNullable(results.getString(%d)).map(%s::fromValue).orElse(null)", number, typeOnly), imp, nil } - return fmt.Sprintf("%s.fromValue(results.getString(%d))", typeOnly, number) + return fmt.Sprintf("%s.fromValue(results.getString(%d))", typeOnly, number), imp, nil } - return fmt.Sprintf("results.getObject(%d, %s.class)", number, typeOnly) + return fmt.Sprintf("results.getObject(%d, %s.class)", number, typeOnly), imp, nil } type Query struct {