Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 61 additions & 14 deletions internal/codegen/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,21 @@ func resultRecordName(q core.Query) string {
return strcase.ToCamel(q.MethodName) + "Row"
}

func createEmbeddedModel(sb *IndentStringBuilder, prefix, suffix string, identLevel, paramIdx int, r core.QueryReturn, embeddedModels core.EmbeddedModels) int {
func createEmbeddedModel(sb *IndentStringBuilder, prefix, suffix string, identLevel, paramIdx int, r core.QueryReturn, embeddedModels core.EmbeddedModels) (int, []string, error) {
imports := make([]string, 0)

modelName := *r.EmbeddedModel
model := embeddedModels[modelName]

sb.WriteIndentedString(identLevel, prefix+modelName+"(\n")
for i, ret := range model {
sb.WriteIndentedString(identLevel+1, ret.ResultStmt(paramIdx))
stm, imp, err := ret.ResultStmt(paramIdx)
if err != nil {
return 0, nil, nil
}

imports = append(imports, imp)
sb.WriteIndentedString(identLevel+1, stm)

if i != len(model)-1 {
sb.WriteString(",\n")
Expand All @@ -30,31 +38,54 @@ func createEmbeddedModel(sb *IndentStringBuilder, prefix, suffix string, identLe
sb.WriteString("\n")
sb.WriteIndentedString(identLevel, suffix)

return paramIdx
return paramIdx, imports, nil
}

func createResultRecord(sb *IndentStringBuilder, indentLevel int, q core.Query, embeddedModels core.EmbeddedModels) {
func createResultRecord(sb *IndentStringBuilder, indentLevel int, q core.Query, embeddedModels core.EmbeddedModels) ([]string, error) {
paramIdx := 1

if len(q.Returns) == 1 {
// set ret to the item directly instead of wrapping it in the result record
if q.Returns[0].EmbeddedModel != nil {
createEmbeddedModel(sb, "var ret = new ", ");\n", indentLevel, paramIdx, q.Returns[0], embeddedModels)
return
_, imps, err := createEmbeddedModel(sb, "var ret = new ", ");\n", indentLevel, paramIdx, q.Returns[0], embeddedModels)
if err != nil {
return nil, err
}
return imps, err
}

sb.WriteIndentedString(indentLevel, "var ret = "+q.Returns[0].ResultStmt(1)+";\n")
return
stm, imp, err := q.Returns[0].ResultStmt(1)
if err != nil {
return nil, err
}

imports := []string{imp}
sb.WriteIndentedString(indentLevel, "var ret = "+stm+";\n")
return imports, err
}

recordName := resultRecordName(q)
sb.WriteIndentedString(indentLevel, "var ret = new "+recordName+"(\n")
imports := make([]string, 0)

for i, ret := range q.Returns {
// if this return is an embedded model we need to do a lil bit extra
if ret.EmbeddedModel != nil {
paramIdx = createEmbeddedModel(sb, "new ", ")", indentLevel+1, paramIdx, ret, embeddedModels)
newIdx, imps, err := createEmbeddedModel(sb, "new ", ")", indentLevel+1, paramIdx, ret, embeddedModels)
if err != nil {
return nil, err
}

imports = append(imports, imps...)
paramIdx = newIdx
} else {
sb.WriteIndentedString(indentLevel+1, ret.ResultStmt(paramIdx))
stm, imp, err := ret.ResultStmt(paramIdx)
if err != nil {
return nil, err
}

imports = append(imports, imp)
sb.WriteIndentedString(indentLevel+1, stm)
}

if i != len(q.Returns)-1 {
Expand All @@ -65,9 +96,10 @@ func createResultRecord(sb *IndentStringBuilder, indentLevel int, q core.Query,
}
sb.WriteString("\n")
sb.WriteIndentedString(indentLevel, ");\n")
return imports, nil
}

func completeMethodBody(sb *IndentStringBuilder, q core.Query, embeddedModels core.EmbeddedModels) {
func completeMethodBody(sb *IndentStringBuilder, q core.Query, embeddedModels core.EmbeddedModels) ([]string, error) {
sb.WriteString("\n")

switch q.Command {
Expand All @@ -84,11 +116,15 @@ func completeMethodBody(sb *IndentStringBuilder, q core.Query, embeddedModels co
sb.WriteIndentedString(2, "if (!results.next()) {\n")
sb.WriteIndentedString(3, "return Optional.empty();\n")
sb.WriteIndentedString(2, "}\n\n")
createResultRecord(sb, 2, q, embeddedModels)
imports, err := createResultRecord(sb, 2, q, embeddedModels)
if err != nil {
return nil, err
}
sb.WriteIndentedString(2, "if (results.next()) {\n")
sb.WriteIndentedString(3, "throw new SQLException(\"expected one row in result set, but got many\");\n")
sb.WriteIndentedString(2, "}\n\n")
sb.WriteIndentedString(2, "return Optional.of(ret);\n")
return imports, nil
case core.Many:
jt := resultRecordName(q)
if len(q.Returns) == 1 {
Expand All @@ -100,10 +136,14 @@ func completeMethodBody(sb *IndentStringBuilder, q core.Query, embeddedModels co

sb.WriteIndentedString(2, "var retList = new ArrayList<"+jt+">();\n")
sb.WriteIndentedString(2, "while (results.next()) {\n")
createResultRecord(sb, 3, q, embeddedModels)
imports, err := createResultRecord(sb, 3, q, embeddedModels)
if err != nil {
return nil, err
}
sb.WriteIndentedString(3, "retList.add(ret);\n")
sb.WriteIndentedString(2, "}\n\n")
sb.WriteIndentedString(2, "return retList;\n")
return imports, nil
case core.Exec:
break
case core.ExecRows:
Expand All @@ -117,6 +157,8 @@ func completeMethodBody(sb *IndentStringBuilder, q core.Query, embeddedModels co
default:
sb.WriteIndentedString(2, "// TODO\n")
}

return nil, nil
}

func BuildQueriesFile(engine string, config core.Config, queryFilename string, queries []core.Query, embeddedModels core.EmbeddedModels, nullableHelpers core.NullableHelpers) (string, []byte, error) {
Expand Down Expand Up @@ -272,7 +314,12 @@ func BuildQueriesFile(engine string, config core.Config, queryFilename string, q
body.WriteString(") throws SQLException {\n")
}

completeMethodBody(methodBody, q, embeddedModels)
imps, err := completeMethodBody(methodBody, q, embeddedModels)
if err != nil {
return "", nil, err
}
imports = append(imports, imps...)

body.WriteString(methodBody.String())
body.WriteIndentedString(1, "}\n")
}
Expand Down
20 changes: 12 additions & 8 deletions internal/core/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,18 @@ type QueryReturn struct {
EmbeddedModel *string
}

func (q QueryReturn) ResultStmt(number int) string {
func (q QueryReturn) ResultStmt(number int) (string, string, error) {
imp, _, err := ResolveImportAndType(q.JavaType.Type)
if err != nil {
return "", "", err
}
typeOnly := q.JavaType.Type[strings.LastIndex(q.JavaType.Type, ".")+1:]

if q.JavaType.IsList {
if q.JavaType.IsNullable {
return fmt.Sprintf("getList(results, %d, %s[].class)", number, typeOnly)
return fmt.Sprintf("getList(results, %d, %s[].class)", number, typeOnly), imp, nil
}
return fmt.Sprintf("Arrays.asList(%s[].class.cast(results.getArray(%d).getArray()))", typeOnly, number)
return fmt.Sprintf("Arrays.asList(%s[].class.cast(results.getArray(%d).getArray()))", typeOnly, number), imp, nil
}

if slices.Contains(literalBindTypes, typeOnly) {
Expand All @@ -137,19 +141,19 @@ func (q QueryReturn) ResultStmt(number int) string {
}

if q.JavaType.IsNullable && ok {
return fmt.Sprintf("get%s(results, %d)", typeOnly, number)
return fmt.Sprintf("get%s(results, %d)", typeOnly, number), imp, nil
}
return fmt.Sprintf("results.get%s(%d)", typeOnly, number)
return fmt.Sprintf("results.get%s(%d)", typeOnly, number), imp, nil
}

if q.JavaType.IsEnum {
if q.JavaType.IsNullable {
return fmt.Sprintf("Optional.ofNullable(results.getString(%d)).map(%s::fromValue).orElse(null)", number, typeOnly)
return fmt.Sprintf("Optional.ofNullable(results.getString(%d)).map(%s::fromValue).orElse(null)", number, typeOnly), imp, nil
}
return fmt.Sprintf("%s.fromValue(results.getString(%d))", typeOnly, number)
return fmt.Sprintf("%s.fromValue(results.getString(%d))", typeOnly, number), imp, nil
}

return fmt.Sprintf("results.getObject(%d, %s.class)", number, typeOnly)
return fmt.Sprintf("results.getObject(%d, %s.class)", number, typeOnly), imp, nil
}

type Query struct {
Expand Down
Loading