diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 29fe462..d051e7c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -9,7 +9,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Setup Go uses: actions/setup-go@v5 @@ -27,7 +27,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Setup Go uses: actions/setup-go@v5 @@ -54,7 +54,7 @@ jobs: run: sqlc generate - name: Setup Java - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: java-version: '17' distribution: 'adopt' diff --git a/.github/workflows/deploy.yaml b/.github/workflows/deploy.yaml index f4e7fea..0c36377 100644 --- a/.github/workflows/deploy.yaml +++ b/.github/workflows/deploy.yaml @@ -22,7 +22,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Update version id: update-version diff --git a/internal/codegen/common.go b/internal/codegen/common.go index 9656141..1572180 100644 --- a/internal/codegen/common.go +++ b/internal/codegen/common.go @@ -2,12 +2,15 @@ package codegen import ( "fmt" - "os" "strings" "github.com/tandemdude/sqlc-gen-java/internal/core" + "github.com/tandemdude/sqlc-gen-java/poet" ) +var resultSetClass = poet.NewClassName("java.sql", "ResultSet") +var sqlExceptionClass = poet.NewClassName("java.sql", "SQLException") + type IndentStringBuilder struct { strings.Builder @@ -27,29 +30,22 @@ func (b *IndentStringBuilder) WriteIndentedString(level int, s string) int { return count } -func (b *IndentStringBuilder) writeSqlcHeader() { - sqlcVersion := os.Getenv("SQLC_VERSION") - - b.WriteString("// Code generated by sqlc. DO NOT EDIT.\n") - b.WriteString("// versions:\n") - b.WriteString("// sqlc " + sqlcVersion + "\n") - b.WriteString("// sqlc-gen-java " + core.PluginVersion + "\n") -} - type nullableHelper struct { ShouldOutput bool - ReturnType string + ReturnType poet.TypeName ArgType string } -func (b *IndentStringBuilder) writeNullableHelpers(nullableHelpers core.NullableHelpers, nonNullAnnotation, nullableAnnotation string) []string { - imports := make([]string, 0) +func writeNullableHelpers(ctx *poet.Context, nullableHelpers core.NullableHelpers, nonNullAnnotation, nullableAnnotation poet.Annotation) []poet.Method { + // FIXME: Annotations + var methods []poet.Method + methodTypes := []nullableHelper{ - {nullableHelpers.Int, "Integer", "Int"}, - {nullableHelpers.Long, "Long", "Long"}, - {nullableHelpers.Float, "Float", "Float"}, - {nullableHelpers.Double, "Double", "Double"}, - {nullableHelpers.Boolean, "Boolean", "Boolean"}, + {nullableHelpers.Int, poet.IntBoxed, "Int"}, + {nullableHelpers.Long, poet.LongBoxed, "Long"}, + {nullableHelpers.Float, poet.FloatBoxed, "Float"}, + {nullableHelpers.Double, poet.DoubleBoxed, "Double"}, + {nullableHelpers.Boolean, poet.BoolBoxed, "Boolean"}, } for _, methodType := range methodTypes { @@ -57,56 +53,64 @@ func (b *IndentStringBuilder) writeNullableHelpers(nullableHelpers core.Nullable continue } - b.WriteIndentedString(1, fmt.Sprintf( - "private static %s get%s(%s rs, int col) throws SQLException {\n", - core.Annotate(methodType.ReturnType, nullableAnnotation), - methodType.ArgType, - core.Annotate("ResultSet", nonNullAnnotation), - )) - b.WriteIndentedString(2, fmt.Sprintf( - "var colVal = rs.get%s(col); return rs.wasNull() ? null : colVal;\n", - methodType.ArgType, - )) - b.WriteIndentedString(1, "}\n") + method := poet.NewMethodBuilder(fmt.Sprintf("get%s", methodType.ArgType), methodType.ReturnType). + WithParameters(poet.NewMethodParam("rs", resultSetClass), poet.NewMethodParam("col", poet.Int)). + WithThrows(sqlExceptionClass). + WithCode( + poet.NewCodeBuilder(). + WithStatement("var colVar = rs.get$L(col)", methodType.ArgType). + WithStatement("return rs.wasNull() ? null : colVal"). + Build(), + ). + Build() + + methods = append(methods, method) } if nullableHelpers.List { - b.WriteIndentedString(1, fmt.Sprintf( - "private static %s getList(%s rs, int col, Class as) throws SQLException {\n", - core.Annotate("List", nullableAnnotation), - core.Annotate("ResultSet", nonNullAnnotation), - )) - b.WriteIndentedString(2, "var colVal = rs.getArray(col); return colVal == null ? null : Arrays.asList(as.cast(colVal.getArray()));\n") - b.WriteIndentedString(1, "}\n") - - imports = append(imports, "java.util.List") + ctx.Import("java.util.Arrays") + genericParamT := poet.NewGenericParam("T") + + method := poet.NewMethodBuilder("getList", poet.ListOf(genericParamT)). + WithGenericParameters(genericParamT). + WithParameters(poet.NewMethodParam("rs", resultSetClass), poet.NewMethodParam("col", poet.Int)). + WithThrows(sqlExceptionClass). + WithCode( + poet.NewCodeBuilder(). + WithStatement("var colVal = rs.getArray(col)"). + WithStatement("return colVal == null ? null : Arrays.asList(as.cast(colVal.getArray()))"). + Build(), + ). + Build() + + methods = append(methods, method) } - return imports + return methods } -func (b *IndentStringBuilder) writeParameter(javaType core.JavaType, name, nonNullAnnotation, nullableAnnotation string) ([]string, error) { - imp, jt, err := core.ResolveImportAndType(javaType.Type) - if err != nil { - return nil, err - } - imports := []string{imp} - - if javaType.IsList { - imports = append(imports, "java.util.List") - jt = "List<" + jt + ">" - } - - annotation := nonNullAnnotation - if javaType.IsNullable { - annotation = nullableAnnotation - } - - newType, unboxed := core.MaybeUnbox(javaType.Type, javaType.IsNullable) - if !unboxed { - newType = core.Annotate(jt, annotation) - } - - b.WriteIndentedString(2, newType+" "+name) - return imports, nil -} +//func (b *IndentStringBuilder) writeParameter(javaType core.JavaType, name, nonNullAnnotation, nullableAnnotation string) ([]string, error) { +// imp, jt, err := core.ResolveImportAndType(javaType.Type) +// if err != nil { +// return nil, err +// } +// imports := []string{imp} +// +// if javaType.IsList { +// imports = append(imports, "java.util.List") +// jt = "List<" + jt + ">" +// } +// +// annotation := nonNullAnnotation +// if javaType.IsNullable { +// annotation = nullableAnnotation +// } +// +// newType, unboxed := core.MaybeUnbox(javaType.Type, javaType.IsNullable) +// if !unboxed { +// newType = core.Annotate(jt, annotation) +// } +// +// b.WriteIndentedString(2, newType+" "+name) +// return imports, nil +//} diff --git a/internal/codegen/enums.go b/internal/codegen/enums.go index 877cf04..da9bbe5 100644 --- a/internal/codegen/enums.go +++ b/internal/codegen/enums.go @@ -9,6 +9,7 @@ import ( "github.com/iancoleman/strcase" "github.com/tandemdude/sqlc-gen-java/internal/core" + "github.com/tandemdude/sqlc-gen-java/poet" ) var javaInvalidIdentChars = regexp.MustCompile("[^$\\w]") @@ -30,47 +31,56 @@ func enumValueName(value string) string { return name } -func BuildEnumFile(engine string, conf core.Config, qualName string, enum core.Enum, defaultSchema string) (string, []byte, error) { - className := EnumClassName(qualName, defaultSchema) +func BuildEnumFile(engine string, config core.Config, qualName string, enum core.Enum, defaultSchema string) (string, []byte, error) { + ctx := poet.NewContext( + config.Package+".models", + poet.WithIndent(strings.Repeat(config.IndentChar, config.CharsPerIndentLevel)), + ) - sb := IndentStringBuilder{indentChar: conf.IndentChar, charsPerIndentLevel: conf.CharsPerIndentLevel} - sb.writeSqlcHeader() - sb.WriteString("\n") - sb.WriteString("package " + conf.Package + ".enums;\n") - sb.WriteString("\n") - sb.WriteString("import javax.annotation.processing.Generated;\n") - sb.WriteString("\n") - sb.WriteString("@Generated(\"io.github.tandemdude.sqlc-gen-java\")\n") - sb.WriteString("public enum " + className + " {\n") + enumName := EnumClassName(qualName, defaultSchema) + + enumBuilder := poet.NewEnumBuilder(enumName). + WithAnnotation( + poet.NewAnnotationBuilder(generatedClass). + WithMember("value", "$S", "io.github.tandemdude.sqlc-gen-java"). + Build(), + ). + WithModifiers(poet.ModifierPublic) if engine == "mysql" { - sb.WriteIndentedString(1, "BLANK(\"\"),\n") + enumBuilder.WithValue("BLANK", "") } // write other values - for i, value := range enum.Values { - name := enumValueName(value) - sb.WriteIndentedString(1, fmt.Sprintf("%s(\"%s\")", name, value)) - - if i < len(enum.Values)-1 { - sb.WriteString(",\n") - } + for _, value := range enum.Values { + enumBuilder.WithValue(enumValueName(value), value) } - sb.WriteString(";\n\n") - sb.WriteIndentedString(1, "private final String value;\n\n") - sb.WriteIndentedString(1, className+"(final String value) {\n") - sb.WriteIndentedString(2, "this.value = value;\n") - sb.WriteIndentedString(1, "}\n\n") - sb.WriteIndentedString(1, "public String getValue() {\n") - sb.WriteIndentedString(2, "return this.value;") - sb.WriteIndentedString(1, "}\n\n") - sb.WriteIndentedString(1, "public static "+className+" fromValue(final String value) {\n") - sb.WriteIndentedString(2, "for (var v : "+className+".values()) {\n") - sb.WriteIndentedString(3, "if (v.value.equals(value)) return v;\n") - sb.WriteIndentedString(2, "}\n") - sb.WriteIndentedString(2, "throw new IllegalArgumentException(\"No enum constant with value \" + value);\n") - sb.WriteIndentedString(1, "}\n") - sb.WriteString("}\n") - return fmt.Sprintf("enums/%s.java", className), []byte(sb.String()), nil + enumType := poet.NewClassName("", enumName) + + enumBuilder.WithMethods( + poet.NewMethodBuilder("getValue", poet.String). + WithCode( + poet.NewCodeBuilder(). + WithStatement("return this.value"). + Build(), + ). + Build(), + + poet.NewMethodBuilder("fromValue", enumType). + // FIXME: Make value final String (?) + WithParameters(poet.NewMethodParam("value", poet.String)). + WithCode( + poet.NewCodeBuilder(). + WithControlFlow("for (var v : $T.values())", func(cb *poet.CodeBuilder) { + cb.WithRawCode("if (v.value.equals(value)) return v;") + }, enumType). + WithStatement(`throw new IllegalArgumentException("No enum constant with value " + value)`). + Build(), + ). + Build(), + ) + + fileContents := poet.FormatFile(ctx, enumBuilder.Build(), poet.WithFileComment(core.FileHeaderComment)) + return fmt.Sprintf("enums/%s.java", enumName), []byte(fileContents), nil } diff --git a/internal/codegen/models.go b/internal/codegen/models.go index 3b586bd..fda2b44 100644 --- a/internal/codegen/models.go +++ b/internal/codegen/models.go @@ -2,65 +2,53 @@ package codegen import ( "fmt" - "slices" "strings" "github.com/iancoleman/strcase" "github.com/tandemdude/sqlc-gen-java/internal/core" + "github.com/tandemdude/sqlc-gen-java/poet" ) func BuildModelFile(config core.Config, name string, model []core.QueryReturn) (string, []byte, error) { - imports := make([]string, 0) - - var nonNullAnnotation string - if config.NonNullAnnotation != "" { - imports = append(imports, config.NonNullAnnotation) - nonNullAnnotation = "@" + config.NonNullAnnotation[strings.LastIndex(config.NonNullAnnotation, ".")+1:] - } - var nullableAnnotation string - if config.NullableAnnotation != "" { - imports = append(imports, config.NullableAnnotation) - nullableAnnotation = "@" + config.NullableAnnotation[strings.LastIndex(config.NullableAnnotation, ".")+1:] - } - - header := NewIndentStringBuilder(config.IndentChar, config.CharsPerIndentLevel) - header.writeSqlcHeader() - header.WriteString("\n") - header.WriteString("package " + config.Package + ".models;\n") - header.WriteString("\n") - header.WriteString("import javax.annotation.processing.Generated;\n") - header.WriteString("\n") - - body := NewIndentStringBuilder(config.IndentChar, config.CharsPerIndentLevel) - body.WriteString("\n") - body.WriteString("@Generated(\"io.github.tandemdude.sqlc-gen-java\")\n") - body.WriteString("public record " + strcase.ToCamel(name) + "(\n") - for i, ret := range model { - imps, err := body.writeParameter(ret.JavaType, ret.Name, nonNullAnnotation, nullableAnnotation) - if err != nil { - return "", nil, err - } - if imps != nil { - imports = append(imports, imps...) - } - - if i != len(model)-1 { - body.WriteString(",\n") - } - } - body.WriteString("\n") - body.WriteString(") {}\n") - - // sort alphabetically and remove duplicate imports - slices.Sort(imports) - imports = slices.Compact(imports) - for _, imp := range imports { - if imp == "" { - continue - } - - header.WriteString("import " + imp + ";\n") + ctx := poet.NewContext( + config.Package+".models", + poet.WithIndent(strings.Repeat(config.IndentChar, config.CharsPerIndentLevel)), + ) + + //var nonNullAnnotation poet.Annotation + //if config.NonNullAnnotation != "" { + // lastIndex := strings.LastIndex(config.NonNullAnnotation, ".") + // pkg := config.NonNullAnnotation[:lastIndex] + // name := config.NonNullAnnotation[lastIndex+1:] + // + // nonNullAnnotation = poet.NewAnnotationBuilder(poet.NewClassName(pkg, name)).Build() + //} + //var nullableAnnotation poet.Annotation + //if config.NullableAnnotation != "" { + // lastIndex := strings.LastIndex(config.NullableAnnotation, ".") + // pkg := config.NullableAnnotation[:lastIndex] + // name := config.NullableAnnotation[lastIndex+1:] + // + // nullableAnnotation = poet.NewAnnotationBuilder(poet.NewClassName(pkg, name)).Build() + //} + + recordName := strcase.ToCamel(name) + + recordBuilder := poet.NewRecordBuilder(recordName). + WithAnnotation( + poet.NewAnnotationBuilder(generatedClass). + WithMember("value", "$S", "io.github.tandemdude.sqlc-gen-java"). + Build(), + ). + WithModifiers(poet.ModifierPublic) + + for _, ret := range model { + // FIXME: Annotations + // Look at common.go:writeParameter + // , nonNullAnnotation, nullableAnnotation + recordBuilder.WithParameters(poet.NewMethodParam(ret.Name, ret.JavaType.Type)) } - return fmt.Sprintf("models/%s.java", strcase.ToCamel(name)), []byte(header.String() + body.String()), nil + fileContents := poet.FormatFile(ctx, recordBuilder.Build(), poet.WithFileComment(core.FileHeaderComment)) + return fmt.Sprintf("models/%s.java", recordName), []byte(fileContents), nil } diff --git a/internal/codegen/queries.go b/internal/codegen/queries.go index f0d0205..632dfed 100644 --- a/internal/codegen/queries.go +++ b/internal/codegen/queries.go @@ -2,14 +2,16 @@ package codegen import ( "errors" - "fmt" - "slices" "strings" "github.com/iancoleman/strcase" "github.com/tandemdude/sqlc-gen-java/internal/core" + "github.com/tandemdude/sqlc-gen-java/poet" ) +var connectionClass = poet.NewClassName("java.sql", "Connection") +var generatedClass = poet.NewClassName("javax.annotation.processing", "Generated") + func resultRecordName(q core.Query) string { return strcase.ToCamel(q.MethodName) + "Row" } @@ -92,7 +94,7 @@ func completeMethodBody(sb *IndentStringBuilder, q core.Query, embeddedModels co case core.Many: jt := resultRecordName(q) if len(q.Returns) == 1 { - _, jt, _ = core.ResolveImportAndType(q.Returns[0].JavaType.Type) + _, jt, _ = core.ResolveImportAndType(q.Returns[0].JavaType.Type.Name) if q.Returns[0].EmbeddedModel != nil { jt = *q.Returns[0].EmbeddedModel } @@ -120,176 +122,162 @@ func completeMethodBody(sb *IndentStringBuilder, q core.Query, embeddedModels co } func BuildQueriesFile(engine string, config core.Config, queryFilename string, queries []core.Query, embeddedModels core.EmbeddedModels, nullableHelpers core.NullableHelpers) (string, []byte, error) { - className := strcase.ToCamel(strings.TrimSuffix(queryFilename, ".sql")) - className = strings.TrimSuffix(className, "Query") - className = strings.TrimSuffix(className, "Queries") - className += "Queries" - - imports := make([]string, 0) - imports = append(imports, "java.sql.SQLException", "java.sql.ResultSet", "java.util.Arrays", "javax.annotation.processing.Generated") + ctx := poet.NewContext( + config.Package, + poet.WithIndent(strings.Repeat(config.IndentChar, config.CharsPerIndentLevel)), + ) - var nonNullAnnotation string + var nonNullAnnotation poet.Annotation if config.NonNullAnnotation != "" { - imports = append(imports, config.NonNullAnnotation) - nonNullAnnotation = "@" + config.NonNullAnnotation[strings.LastIndex(config.NonNullAnnotation, ".")+1:] + lastIndex := strings.LastIndex(config.NonNullAnnotation, ".") + pkg := config.NonNullAnnotation[:lastIndex] + name := config.NonNullAnnotation[lastIndex+1:] + + nonNullAnnotation = poet.NewAnnotationBuilder(poet.NewClassName(pkg, name)).Build() } - var nullableAnnotation string + var nullableAnnotation poet.Annotation if config.NullableAnnotation != "" { - imports = append(imports, config.NullableAnnotation) - nullableAnnotation = "@" + config.NullableAnnotation[strings.LastIndex(config.NullableAnnotation, ".")+1:] + lastIndex := strings.LastIndex(config.NullableAnnotation, ".") + pkg := config.NullableAnnotation[:lastIndex] + name := config.NullableAnnotation[lastIndex+1:] + + nullableAnnotation = poet.NewAnnotationBuilder(poet.NewClassName(pkg, name)).Build() } - header := NewIndentStringBuilder(config.IndentChar, config.CharsPerIndentLevel) - header.writeSqlcHeader() - header.WriteString("\n") - header.WriteString("package " + config.Package + ";\n") - header.WriteString("\n") - - body := NewIndentStringBuilder(config.IndentChar, config.CharsPerIndentLevel) - body.WriteString("\n") - // Add the class declaration and constructor - body.WriteString("@Generated(\"io.github.tandemdude.sqlc-gen-java\")\n") - body.WriteString("public class " + className + " {\n") - body.WriteIndentedString(1, "private final java.sql.Connection conn;\n\n") - body.WriteIndentedString(1, "public "+className+"(java.sql.Connection conn) {\n") - body.WriteIndentedString(2, "this.conn = conn;\n") - body.WriteIndentedString(1, "}\n") + className := strcase.ToCamel(strings.TrimSuffix(queryFilename, ".sql")) + className = strings.TrimSuffix(className, "Query") + className = strings.TrimSuffix(className, "Queries") + className += "Queries" + + classBuilder := poet.NewClassBuilder(className). + WithAnnotation( + poet.NewAnnotationBuilder(generatedClass). + WithMember("value", "$S", "io.github.tandemdude.sqlc-gen-java"). + Build(), + ). + WithModifiers(poet.ModifierPublic). + WithFields(poet.ClassField{ + Name: "conn", + Type: connectionClass, + Modifiers: []poet.Modifier{poet.ModifierPrivate, poet.ModifierFinal}, + }). + WithConstructor( + poet.NewConstructorBuilder(). + WithParameters(poet.NewMethodParam("conn", connectionClass)). + WithCode( + poet.NewCodeBuilder(). + WithStatement("this.conn = conn"). + Build(), + ). + Build(), + ) + + methods := writeNullableHelpers(ctx, nullableHelpers, nonNullAnnotation, nullableAnnotation) + var methodBuilder *poet.MethodBuilder + var method poet.Method if config.ExposeConnection { - body.WriteString("\n") - body.WriteIndentedString(1, "public java.sql.Connection getConn() {return this.conn;}\n") + method = poet.NewMethodBuilder("getConn", connectionClass). + WithCode(poet.NewCodeBuilder(). + WithStatement("return this.conn"). + Build(), + ). + Build() + + methods = append(methods, method) } - // boilerplate methods to allow for getting null primitive values - body.WriteString("\n") - - imp := body.writeNullableHelpers(nullableHelpers, nonNullAnnotation, nullableAnnotation) - imports = append(imports, imp...) - for _, q := range queries { - body.WriteString("\n") - - // write the static attribute containing the query string - body.WriteIndentedString(1, "private static final String "+q.MethodName+" = \"\"\"\n") - body.WriteIndentedString(2, "-- name: "+q.RawQueryName+" "+q.RawCommand+"\n") - // for each line in the query, ensure it is indented correctly + queryStrBuilder := NewIndentStringBuilder(config.IndentChar, config.CharsPerIndentLevel) + queryStrBuilder.WriteString("\"\"\"\n") + queryStrBuilder.WriteIndentedString(1, "-- name: "+q.RawQueryName+" "+q.RawCommand+"\n") for _, part := range strings.Split(q.Text, "\n") { if part == "" { continue } - body.WriteIndentedString(2, part+"\n") + queryStrBuilder.WriteIndentedString(1, part+"\n") } - body.WriteIndentedString(2, "\"\"\";\n") + queryStrBuilder.WriteIndentedString(1, "\"\"\";") + + classBuilder.WithFields( + poet.NewClassFieldBuilder(q.MethodName, poet.String). + WithModifiers(poet.ModifierPublic, poet.ModifierStatic, poet.ModifierFinal). + WithInitializer(queryStrBuilder.String()). + Build(), + ) // write the output record class - var returnType string + var returnType poet.TypeName if len(q.Returns) > 1 { - returnType = resultRecordName(q) - - body.WriteString("\n") - body.WriteIndentedString(1, "public record "+returnType+"(\n") - for i, ret := range q.Returns { - imps, err := body.writeParameter(ret.JavaType, ret.Name, nonNullAnnotation, nullableAnnotation) - if err != nil { - return "", nil, err - } - if imps != nil { - imports = append(imports, imps...) - } - - if i != len(q.Returns)-1 { - body.WriteString(",\n") - } + recordName := resultRecordName(q) + recordBuilder := poet.NewRecordBuilder(recordName) + + // FIXME: Annotations + // Look at common.go:writeParameter + // , nonNullAnnotation, nullableAnnotation + for _, ret := range q.Returns { + recordBuilder.WithParameters(poet.NewMethodParam(ret.Name, ret.JavaType.Type)) } - body.WriteString("\n") - body.WriteIndentedString(1, ") {}\n") + + classBuilder.WithMembers(recordBuilder.Build()) + + returnType = poet.NewClassName("", recordName) } else if len(q.Returns) == 1 { // the query only outputs a single value, we don't need to wrap it in an xxRow record class ret := q.Returns[0] - imp, jt, err := core.ResolveImportAndType(ret.JavaType.Type) - if err != nil { - return "", nil, err - } - imports = append(imports, imp) - if ret.JavaType.IsList { - imports = append(imports, "java.util.List") - jt = "List<" + jt + ">" + returnType = poet.ListOf(ret.JavaType.Type) + } else { + returnType = ret.JavaType.Type } - - returnType = jt } // figure out what the return type of the method should be switch q.Command { case core.One: - imports = append(imports, "java.util.Optional") - returnType = "Optional<" + returnType + ">" + returnType = poet.OptionalOf(returnType) case core.Many: - imports = append(imports, "java.util.List", "java.util.ArrayList") - returnType = "List<" + returnType + ">" + returnType = poet.ListOf(returnType) case core.Exec: - returnType = "void" + returnType = poet.Void case core.ExecRows: - returnType = "int" + returnType = poet.Int case core.ExecResult: - returnType = "long" + returnType = poet.Long case core.CopyFrom: return "", []byte{}, errors.New("copyFrom is not currently supported") } - methodBody := NewIndentStringBuilder(config.IndentChar, config.CharsPerIndentLevel) + methodBuilder = poet.NewMethodBuilder(q.MethodName, returnType).WithModifiers(poet.ModifierPublic).WithThrows(sqlExceptionClass) + codeBuilder := poet.NewCodeBuilder() + if q.Command == core.ExecResult { - methodBody.WriteIndentedString(2, "var stmt = conn.prepareStatement("+q.MethodName+", java.sql.Statement.RETURN_GENERATED_KEYS);\n") + codeBuilder.WithStatement("var stmt = conn.prepareStatement($L, java.sql.Statement.RETURN_GENERATED_KEYS)", q.MethodName) } else { - methodBody.WriteIndentedString(2, "var stmt = conn.prepareStatement("+q.MethodName+");\n") + codeBuilder.WithStatement("var stmt = conn.prepareStatement($L)", q.MethodName) } - // write the method signature - body.WriteString("\n") - body.WriteIndentedString(1, fmt.Sprintf("public %s %s(", returnType, q.MethodName)) - if len(q.Args) > 0 { - body.WriteString("\n") - - for i, arg := range q.Args { - imps, err := body.writeParameter(arg.JavaType, arg.Name, nonNullAnnotation, nullableAnnotation) - if err != nil { - return "", nil, err - } - if imps != nil { - imports = append(imports, imps...) - } - - if i != len(q.Args)-1 { - body.WriteString(",\n") - } - - methodBody.WriteIndentedString(2, arg.BindStmt(engine)+"\n") - } - body.WriteString("\n") - body.WriteIndentedString(1, ") throws SQLException {\n") - } else { - body.WriteString(") throws SQLException {\n") + for _, arg := range q.Args { + // FIXME: Annotations + // Look at common.go:writeParameter + // , nonNullAnnotation, nullableAnnotation + methodBuilder.WithParameters(poet.NewMethodParam(arg.Name, arg.JavaType.Type)) + // FIXME: Make BindStmt take in the code builder + codeBuilder.WithRawCode(arg.BindStmt(engine)) } - completeMethodBody(methodBody, q, embeddedModels) - body.WriteString(methodBody.String()) - body.WriteIndentedString(1, "}\n") - } - body.WriteString("}\n") - - // sort alphabetically and remove duplicate imports - slices.Sort(imports) - imports = slices.Compact(imports) - for _, imp := range imports { - if imp == "" { - continue - } + // FIXME: finish + //methodBody := NewIndentStringBuilder(config.IndentChar, config.CharsPerIndentLevel) + //completeMethodBody(methodBody, q, embeddedModels) - header.WriteString("import " + imp + ";\n") + method = methodBuilder.WithCode(codeBuilder.Build()).Build() + methods = append(methods, method) } - return className + ".java", []byte(header.String() + body.String()), nil + classBuilder.WithMethods(methods...) + + fileContents := poet.FormatFile(ctx, classBuilder.Build(), poet.WithFileComment(core.FileHeaderComment)) + return className + ".java", []byte(fileContents), nil } diff --git a/internal/core/constants.go b/internal/core/constants.go index dea5de9..3adf191 100644 --- a/internal/core/constants.go +++ b/internal/core/constants.go @@ -5,7 +5,7 @@ import ( "strings" ) -const PluginVersion = "0.0.6" +const PluginVersion = "0.0.7" var FileHeaderComment = strings.Join([]string{ "// Code generated by sqlc. DO NOT EDIT.", diff --git a/internal/core/models.go b/internal/core/models.go index 86e2f5a..06c7b33 100644 --- a/internal/core/models.go +++ b/internal/core/models.go @@ -3,7 +3,8 @@ package core import ( "fmt" "slices" - "strings" + + "github.com/tandemdude/sqlc-gen-java/poet" ) type QueryCommand int @@ -36,9 +37,10 @@ func QueryCommandFor(rawCommand string) (QueryCommand, error) { } } +// FIXME: I think JavaType can be simplified to poet.TypeName type JavaType struct { SqlType string - Type string + Type poet.TypeName IsList bool IsNullable bool IsEnum bool @@ -70,7 +72,7 @@ var typeToJavaSqlTypeConst = map[string]string{ } func (q QueryArg) BindStmt(engine string) string { - typeOnly := q.JavaType.Type[strings.LastIndex(q.JavaType.Type, ".")+1:] + typeOnly := q.JavaType.Type.Name if q.JavaType.IsList { if q.JavaType.IsNullable { @@ -126,7 +128,7 @@ type QueryReturn struct { } func (q QueryReturn) ResultStmt(number int) string { - typeOnly := q.JavaType.Type[strings.LastIndex(q.JavaType.Type, ".")+1:] + typeOnly := q.JavaType.Type.Name if q.JavaType.IsList { if q.JavaType.IsNullable { diff --git a/internal/gen.go b/internal/gen.go index 57dc64b..51459ec 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -17,6 +17,7 @@ import ( "github.com/tandemdude/sqlc-gen-java/internal/core" "github.com/tandemdude/sqlc-gen-java/internal/inflection" "github.com/tandemdude/sqlc-gen-java/internal/sqltypes" + "github.com/tandemdude/sqlc-gen-java/poet" ) var ( @@ -96,7 +97,7 @@ func (gen *JavaGenerator) fixQueryPlaceholders(query string) (string, error) { func (gen *JavaGenerator) parseQueryReturn(col *plugin.Column) (*core.QueryReturn, error) { isEnum := false - strJavaType, err := gen.typeConversionFunc(col.Type) + poetType, err := gen.typeConversionFunc(col.Type) if err != nil { schema := col.Table.Schema if schema == "" { @@ -109,7 +110,7 @@ func (gen *JavaGenerator) parseQueryReturn(col *plugin.Column) (*core.QueryRetur } gen.usedEnums = append(gen.usedEnums, enumQualifiedName) - strJavaType = gen.conf.Package + ".enums." + codegen.EnumClassName(enumQualifiedName, gen.req.Catalog.DefaultSchema) + poetType = poet.NewClassName(gen.conf.Package+".enums", codegen.EnumClassName(enumQualifiedName, gen.req.Catalog.DefaultSchema)) isEnum = true } @@ -119,7 +120,7 @@ func (gen *JavaGenerator) parseQueryReturn(col *plugin.Column) (*core.QueryRetur javaType := core.JavaType{ SqlType: sdk.DataType(col.Type), - Type: strJavaType, + Type: poetType, IsList: col.IsArray, IsNullable: !col.NotNull, IsEnum: isEnum, @@ -128,8 +129,8 @@ func (gen *JavaGenerator) parseQueryReturn(col *plugin.Column) (*core.QueryRetur if javaType.IsNullable { if javaType.IsList { gen.nullableHelpers.List = true - } else { - switch strJavaType { + } else if poetType.Package == "" { + switch poetType.Name { case "Integer": gen.nullableHelpers.Int = true case "Long": @@ -175,9 +176,9 @@ func (gen *JavaGenerator) Run() (*plugin.GenerateResponse, error) { // TODO - enum types? other specialness? args := make([]core.QueryArg, 0) - for _, arg := range query.Params { + for index, arg := range query.Params { isEnum := false - javaType, err := gen.typeConversionFunc(arg.Column.Type) + poetType, err := gen.typeConversionFunc(arg.Column.Type) if err != nil { // check if this is an enum type schema := arg.Column.Table.Schema @@ -191,7 +192,7 @@ func (gen *JavaGenerator) Run() (*plugin.GenerateResponse, error) { } gen.usedEnums = append(gen.usedEnums, enumQualifiedName) - javaType = gen.conf.Package + ".enums." + codegen.EnumClassName(enumQualifiedName, gen.req.Catalog.DefaultSchema) + poetType = poet.NewClassName(gen.conf.Package+".enums", codegen.EnumClassName(enumQualifiedName, gen.req.Catalog.DefaultSchema)) isEnum = true } @@ -199,12 +200,17 @@ func (gen *JavaGenerator) Run() (*plugin.GenerateResponse, error) { return nil, fmt.Errorf("multidimensional arrays are not supported, store JSON instead") } + columnName := arg.Column.Name + if columnName == "" { + columnName = fmt.Sprintf("column%d", index+1) + } + args = append(args, core.QueryArg{ Number: int(arg.Number), - Name: strcase.ToLowerCamel(arg.Column.Name), + Name: strcase.ToLowerCamel(columnName), JavaType: core.JavaType{ SqlType: sdk.DataType(arg.Column.Type), - Type: javaType, + Type: poetType, IsList: arg.Column.IsArray, IsNullable: !arg.Column.NotNull, IsEnum: isEnum, @@ -277,7 +283,7 @@ func (gen *JavaGenerator) Run() (*plugin.GenerateResponse, error) { JavaType: core.JavaType{ SqlType: "", // we don't need to specify package here - models file will be generated in the same location as the queries file - Type: gen.conf.Package + ".models." + modelName, + Type: poet.NewClassName(gen.conf.Package+".models", modelName), IsList: false, // TODO - check: this *should* be impossible IsNullable: false, // TODO - check: empty record should be output instead }, diff --git a/internal/sqltypes/common.go b/internal/sqltypes/common.go index f9acbc8..7c8e181 100644 --- a/internal/sqltypes/common.go +++ b/internal/sqltypes/common.go @@ -1,5 +1,8 @@ package sqltypes -import "github.com/sqlc-dev/plugin-sdk-go/plugin" +import ( + "github.com/sqlc-dev/plugin-sdk-go/plugin" + "github.com/tandemdude/sqlc-gen-java/poet" +) -type TypeConversionFunc func(*plugin.Identifier) (string, error) +type TypeConversionFunc func(*plugin.Identifier) (poet.TypeName, error) diff --git a/internal/sqltypes/mysql.go b/internal/sqltypes/mysql.go index 2c030d6..ab7137e 100644 --- a/internal/sqltypes/mysql.go +++ b/internal/sqltypes/mysql.go @@ -9,39 +9,7 @@ import ( "github.com/sqlc-dev/plugin-sdk-go/sdk" ) -func MysqlTypeToJavaType(identifier *plugin.Identifier) (string, error) { - colType := sdk.DataType(identifier) - - switch colType { - case "varchar", "text", "char", "tinytext", "mediumtext", "longtext": - return "String", nil - case "int", "integer", "smallint", "mediumint", "year": - return "Integer", nil - case "bigint": - return "Long", nil - case "blob", "binary", "varbinary", "tinyblob", "mediumblob", "longblob": - return "byte[]", nil - case "double", "double precision", "real": - return "Double", nil - case "decimal", "dec", "fixed": - return "java.math.BigDecimal", nil - case "date": - return "java.time.LocalDate", nil - case "datetime", "time": - return "java.time.LocalDateTime", nil - // TODO - instant support - look into option for this in pgsql as well - case "timestamp": - return "java.time.OffsetDateTime", nil - case "boolean", "bool", "tinyint": - return "Boolean", nil - case "json": - return "String", nil - default: - return "", fmt.Errorf("datatype '%s' not currently supported", colType) - } -} - -func ConvertMySQLType(identifier *plugin.Identifier) (poet.TypeName, error) { +func MysqlTypeToJavaType(identifier *plugin.Identifier) (poet.TypeName, error) { colType := sdk.DataType(identifier) switch colType { diff --git a/internal/sqltypes/postgresql.go b/internal/sqltypes/postgresql.go index 92b95dd..d7e9d3e 100644 --- a/internal/sqltypes/postgresql.go +++ b/internal/sqltypes/postgresql.go @@ -9,48 +9,7 @@ import ( "github.com/sqlc-dev/plugin-sdk-go/sdk" ) -func PostgresTypeToJavaType(identifier *plugin.Identifier) (string, error) { - colType := sdk.DataType(identifier) - - switch colType { - case "serial", "pg_catalog.serial4", "integer", "int", "int4", "pg_catalog.int4": - return "Integer", nil - case "bigserial", "pg_catalog.serial8", "bigint", "pg_catalog.int8": - return "Long", nil - case "smallserial", "pg_catalog.serial2", "smallint", "pg_catalog.int2": - return "Short", nil - case "float", "double precision", "pg_catalog.float8": - return "Double", nil - case "real", "pg_catalog.float4": - return "Float", nil - case "pg_catalog.numeric": - return "java.math.BigDecimal", nil - case "bool", "pg_catalog.bool": - return "Boolean", nil - case "bytea", "blob", "pg_catalog.bytea": - return "byte[]", nil - case "date": - return "java.time.LocalDate", nil - case "pg_catalog.time", "pg_catalog.timetz": - return "java.time.LocalTime", nil - case "pg_catalog.timestamp", "timestamp": - return "java.time.LocalDateTime", nil - case "pg_catalog.timestamptz", "timestamptz": - return "java.time.OffsetDateTime", nil - case "text", "pg_catalog.varchar", "pg_catalog.bpchar", "string": - return "String", nil - case "uuid": - return "java.util.UUID", nil - // TODO - figure out if these can be supported properly - case "jsonb", "inet": - return "String", nil - default: - // void, any - return "", fmt.Errorf("datatype '%s' not currently supported", colType) - } -} - -func ConvertPostgresType(identifier *plugin.Identifier) (poet.TypeName, error) { +func PostgresTypeToJavaType(identifier *plugin.Identifier) (poet.TypeName, error) { colType := sdk.DataType(identifier) switch colType { diff --git a/poet/annotation.go b/poet/annotation.go new file mode 100644 index 0000000..eb5fae9 --- /dev/null +++ b/poet/annotation.go @@ -0,0 +1,59 @@ +package poet + +import "strings" + +type Annotation struct { + Class TypeName + Members map[string]Code +} + +func (a Annotation) Format(ctx *Context) string { + var sb strings.Builder + + sb.WriteRune('@') + sb.WriteString(a.Class.Format(ctx)) + + // FIXME: Add support for embedded annotations + // .addAnnotation(AnnotationSpec.builder(HeaderList.class) + // .addMember("value", "$L", AnnotationSpec.builder(Header.class) + // .addMember("name", "$S", "Accept") + // .addMember("value", "$S", "application/json; charset=utf-8") + // .build()) + // .addMember("value", "$L", AnnotationSpec.builder(Header.class) + // .addMember("name", "$S", "User-Agent") + // .addMember("value", "$S", "Square Cash") + // .build()) + // .build()) + if value, ok := a.Members["value"]; len(a.Members) == 1 && ok { + sb.WriteRune('(') + sb.WriteString(value.Format(ctx)) + sb.WriteRune(')') + } else if len(a.Members) > 0 { + sb.WriteString("(\n") + for name, code := range a.Members { + sb.WriteString(name) + sb.WriteString(" = ") + sb.WriteString(code.Format(ctx)) + } + sb.WriteString("\n)") + } + + return sb.String() +} + +type AnnotationBuilder struct { + annotation Annotation +} + +func NewAnnotationBuilder(class TypeName) *AnnotationBuilder { + return &AnnotationBuilder{annotation: Annotation{Class: class, Members: make(map[string]Code)}} +} + +func (b *AnnotationBuilder) WithMember(name string, value string, args ...any) *AnnotationBuilder { + b.annotation.Members[name] = Code{RawCode: value, Arguments: args} + return b +} + +func (b *AnnotationBuilder) Build() Annotation { + return b.annotation +} diff --git a/poet/code.go b/poet/code.go index b150e20..0be3e93 100644 --- a/poet/code.go +++ b/poet/code.go @@ -19,6 +19,7 @@ type Code struct { IsFlow bool IsTryCatch bool IsIfElse bool + IsStmt bool Arguments []any Statements []Code @@ -121,7 +122,7 @@ func (c *Code) Format(ctx *Context) string { if c.RawCode != "" && !c.IsFlow { // Simple statement sb.WriteString(formatRawCode(ctx, c.RawCode, c.Arguments)) - if !strings.HasSuffix(c.RawCode, ";") { + if c.IsStmt && !strings.HasSuffix(c.RawCode, ";") { sb.WriteRune(';') } @@ -143,7 +144,7 @@ func NewCodeBuilder() *CodeBuilder { } func (b *CodeBuilder) WithStatement(stmt string, args ...any) *CodeBuilder { - b.code.Statements = append(b.code.Statements, Code{RawCode: stmt, Arguments: args}) + b.code.Statements = append(b.code.Statements, Code{RawCode: stmt, Arguments: args, IsStmt: true}) return b } diff --git a/poet/method.go b/poet/method.go index 30b9997..7ff5137 100644 --- a/poet/method.go +++ b/poet/method.go @@ -35,8 +35,11 @@ type Method struct { func (m Method) Format(ctx *Context) string { var sb strings.Builder - sb.WriteString(formatModifiers(m.Modifiers)) - sb.WriteString(" ") + if len(m.Modifiers) > 0 { + sb.WriteString(formatModifiers(m.Modifiers)) + sb.WriteString(" ") + } + writeGenericParamList(ctx, &sb, m.GenericParameters, true) if !m.isConstructor { sb.WriteString(m.ReturnType.Format(ctx, ExcludeConstraints)) @@ -116,7 +119,6 @@ func (b *MethodBuilder) WithCode(code Code) *MethodBuilder { } func (b *MethodBuilder) Build() Method { - b.method.Modifiers = maybeSetPackagePrivate(b.method.Modifiers) return b.method } @@ -153,6 +155,5 @@ func (b *ConstructorBuilder) WithCode(code Code) *ConstructorBuilder { } func (b *ConstructorBuilder) Build() Constructor { - b.constructor.Modifiers = maybeSetPackagePrivate(b.constructor.Modifiers) return b.constructor } diff --git a/poet/modifier.go b/poet/modifier.go index dcab351..e3462f6 100644 --- a/poet/modifier.go +++ b/poet/modifier.go @@ -8,7 +8,6 @@ type Modifier int const ( ModifierPrivate Modifier = iota - ModifierPackagePrivate ModifierProtected ModifierPublic ModifierAbstract @@ -16,19 +15,10 @@ const ( ModifierFinal ) -var accessModifiers = []Modifier{ - ModifierPrivate, - ModifierPackagePrivate, - ModifierProtected, - ModifierPublic, -} - func formatModifier(modifier Modifier) string { switch modifier { case ModifierPrivate: return "private" - case ModifierPackagePrivate: - return "" case ModifierProtected: return "protected" case ModifierPublic: diff --git a/poet/name.go b/poet/name.go index e9f5bea..9d5f433 100644 --- a/poet/name.go +++ b/poet/name.go @@ -4,8 +4,6 @@ import ( "strings" ) -// TODO - annotation support - type TypeName struct { Package string Name string @@ -13,6 +11,9 @@ type TypeName struct { IsBuiltin bool IsArray bool + // FIXME: We need builders for this + Annotations []Annotation + IsParameterized bool Parameters []TypeName @@ -20,8 +21,8 @@ type TypeName struct { Extends []TypeName } -func NewClassName(pkg, name string) TypeName { - return TypeName{Package: pkg, Name: name} +func NewClassName(pkg, name string, annotations ...Annotation) TypeName { + return TypeName{Package: pkg, Name: name, Annotations: annotations} } func (t TypeName) Array() TypeName { @@ -47,37 +48,6 @@ func NewGenericParam(name string, extends ...TypeName) TypeName { } } -func (t TypeName) Equals(other TypeName) bool { - if t.Package != other.Package || - t.Name != other.Name || - t.IsParameterized != other.IsParameterized || - t.IsGeneric != other.IsGeneric || - t.IsArray != other.IsArray { - return false - } - - if len(t.Parameters) != len(other.Parameters) { - return false - } - for i := range t.Parameters { - if !t.Parameters[i].Equals(other.Parameters[i]) { - return false - } - } - - if len(t.Extends) != len(other.Extends) { - return false - } - // TODO - is order actually important? - for i := range t.Extends { - if !t.Extends[i].Equals(other.Extends[i]) { - return false - } - } - - return true -} - type FormatOption int var ( @@ -98,22 +68,25 @@ func (t TypeName) Format(ctx *Context, options ...FormatOption) string { var bld strings.Builder - var typename string if t.Package != "" { // check if we need to use the fully qualified type name due to a collision existing, ok := ctx.Types[t.Name] - if (ok && !t.Equals(existing)) || t.Name == ctx.CurrentTypeName { + if (ok && !(existing.Package == t.Package && existing.Name == t.Name)) || t.Name == ctx.CurrentTypeName { bld.WriteString(t.Package) bld.WriteString(".") - typename = t.Package + "." + t.Name } else { if !t.IsBuiltin { - ctx.Import(t.Package) + ctx.Import(t.Package + "." + t.Name) + ctx.Types[t.Name] = t } - typename = t.Name } } + for _, annotation := range t.Annotations { + bld.WriteString(annotation.Format(ctx)) + bld.WriteString(" ") + } + bld.WriteString(t.Name) if t.IsGeneric && len(t.Extends) > 0 && !opts.has(ExcludeConstraints) { bld.WriteString(" extends ") @@ -139,11 +112,6 @@ func (t TypeName) Format(ctx *Context, options ...FormatOption) string { bld.WriteString("[]") } - // generic type names are not necessarily unique within a file - if !t.IsGeneric { - ctx.Types[typename] = t - } // ctx.GenericTypes = append(ctx.GenericTypes, t) - return bld.String() } @@ -184,7 +152,8 @@ func newTwoParameterType(pkg, name string) func(TypeName, TypeName) TypeName { } var ( - ListOf = newSingleParameterType("java.util", "List") - MapOf = newTwoParameterType("java.util", "Map") - SetOf = newSingleParameterType("java.util", "Set") + ListOf = newSingleParameterType("java.util", "List") + MapOf = newTwoParameterType("java.util", "Map") + SetOf = newSingleParameterType("java.util", "Set") + OptionalOf = newSingleParameterType("java.util", "Optional") ) diff --git a/poet/type.go b/poet/type.go index 0ba1992..d87a095 100644 --- a/poet/type.go +++ b/poet/type.go @@ -5,22 +5,67 @@ import ( "strings" ) -// TODO - annotation support - type ClassField struct { - Name string - Type TypeName - Modifiers []Modifier + Name string + Type TypeName + Initializer Code + Modifiers []Modifier + HasInitializer bool +} + +func (c ClassField) Format(ctx *Context) string { + if !c.HasInitializer { + return fmt.Sprintf( + "%s %s %s;", + formatModifiers(c.Modifiers), + c.Type.Format(ctx, ExcludeConstraints), + c.Name, + ) + } + + return fmt.Sprintf( + "%s %s %s = %s", + formatModifiers(c.Modifiers), + c.Type.Format(ctx, ExcludeConstraints), + c.Name, + c.Initializer.Format(ctx), + ) +} + +type ClassFieldBuilder struct { + classField ClassField +} + +func NewClassFieldBuilder(name string, class TypeName) *ClassFieldBuilder { + return &ClassFieldBuilder{classField: ClassField{Name: name, Type: class}} +} + +func (b *ClassFieldBuilder) WithModifiers(modifiers ...Modifier) *ClassFieldBuilder { + b.classField.Modifiers = append(b.classField.Modifiers, modifiers...) + return b +} + +func (b *ClassFieldBuilder) WithInitializer(code string, args ...any) *ClassFieldBuilder { + b.classField.HasInitializer = true + b.classField.Initializer = Code{RawCode: code, Arguments: args} + return b +} + +func (b *ClassFieldBuilder) Build() ClassField { + return b.classField } type Class struct { Name string + Annotations []Annotation Modifiers []Modifier GenericParameters []TypeName Constructor *Constructor Fields []ClassField Methods []Method + Members []formattable + // FIXME: Idk if formattable is ideal here } func (c Class) name() string { @@ -30,6 +75,11 @@ func (c Class) name() string { func (c Class) Format(ctx *Context) string { var sb strings.Builder + for _, annotation := range c.Annotations { + sb.WriteString(annotation.Format(ctx)) + sb.WriteRune('\n') + } + sb.WriteString(formatModifiers(c.Modifiers)) if sb.Len() > 0 { sb.WriteString(" ") @@ -40,17 +90,14 @@ func (c Class) Format(ctx *Context) string { writeGenericParamList(ctx, &sb, c.GenericParameters, false) sb.WriteString(" {\n") - for i, field := range c.Fields { - sb.WriteString(ctx.indent(fmt.Sprintf( - "%s %s %s;\n", - formatModifiers(field.Modifiers), - field.Type.Format(ctx, ExcludeConstraints), - field.Name, - ))) + for _, field := range c.Fields { + sb.WriteString(ctx.indent(field.Format(ctx))) + sb.WriteString("\n\n") + } - if i == len(c.Fields)-1 { - sb.WriteString("\n") - } + for _, method := range c.Members { + sb.WriteString(ctx.indent(method.Format(ctx))) + sb.WriteString("\n\n") } if c.Constructor != nil { @@ -83,6 +130,11 @@ func NewClassBuilder(name string) *ClassBuilder { return &ClassBuilder{class: Class{Name: name}} } +func (c *ClassBuilder) WithAnnotation(annotations ...Annotation) *ClassBuilder { + c.class.Annotations = append(c.class.Annotations, annotations...) + return c +} + func (c *ClassBuilder) WithModifiers(modifiers ...Modifier) *ClassBuilder { c.class.Modifiers = appendModifiers(c.class.Modifiers, modifiers) return c @@ -108,8 +160,12 @@ func (c *ClassBuilder) WithMethods(methods ...Method) *ClassBuilder { return c } +func (c *ClassBuilder) WithMembers(members ...formattable) *ClassBuilder { + c.class.Members = append(c.class.Members, members...) + return c +} + func (c *ClassBuilder) Build() Class { - c.class.Modifiers = maybeSetPackagePrivate(c.class.Modifiers) return c.class } @@ -126,9 +182,10 @@ func NewEnumValue(name string, value string) EnumValue { type Enum struct { Name string - Modifiers []Modifier - Values []EnumValue - Methods []Method + Modifiers []Modifier + Values []EnumValue + Methods []Method + Annotations []Annotation } func (e Enum) name() string { @@ -138,6 +195,11 @@ func (e Enum) name() string { func (e Enum) Format(ctx *Context) string { var sb strings.Builder + for _, annotation := range e.Annotations { + sb.WriteString(annotation.Format(ctx)) + sb.WriteRune('\n') + } + sb.WriteString(formatModifiers(e.Modifiers)) if sb.Len() > 0 { sb.WriteString(" ") @@ -186,6 +248,11 @@ func NewEnumBuilder(name string) *EnumBuilder { return &EnumBuilder{enum: Enum{Name: name}} } +func (b *EnumBuilder) WithAnnotation(annotations ...Annotation) *EnumBuilder { + b.enum.Annotations = append(b.enum.Annotations, annotations...) + return b +} + func (b *EnumBuilder) WithModifiers(modifiers ...Modifier) *EnumBuilder { b.enum.Modifiers = appendModifiers(b.enum.Modifiers, modifiers) return b @@ -207,16 +274,16 @@ func (b *EnumBuilder) WithMethods(methods ...Method) *EnumBuilder { } func (b *EnumBuilder) Build() Enum { - b.enum.Modifiers = maybeSetPackagePrivate(b.enum.Modifiers) return b.enum } type Record struct { Name string - Modifiers []Modifier - Parameters []MethodParameter - Methods []Method + Annotations []Annotation + Modifiers []Modifier + Parameters []MethodParameter + Methods []Method } func (r Record) name() string { @@ -226,6 +293,11 @@ func (r Record) name() string { func (r Record) Format(ctx *Context) string { var sb strings.Builder + for _, annotation := range r.Annotations { + sb.WriteString(annotation.Format(ctx)) + sb.WriteRune('\n') + } + sb.WriteString(formatModifiers(r.Modifiers)) if sb.Len() > 0 { sb.WriteString(" ") @@ -270,6 +342,11 @@ func NewRecordBuilder(name string) *RecordBuilder { return &RecordBuilder{record: Record{Name: name}} } +func (b *RecordBuilder) WithAnnotation(annotations ...Annotation) *RecordBuilder { + b.record.Annotations = append(b.record.Annotations, annotations...) + return b +} + func (b *RecordBuilder) WithModifiers(modifiers ...Modifier) *RecordBuilder { b.record.Modifiers = appendModifiers(b.record.Modifiers, modifiers) return b @@ -286,6 +363,5 @@ func (b *RecordBuilder) WithMethods(methods ...Method) *RecordBuilder { } func (b *RecordBuilder) Build() Record { - b.record.Modifiers = maybeSetPackagePrivate(b.record.Modifiers) return b.record } diff --git a/poet/utils.go b/poet/utils.go index cadb040..b47e45e 100644 --- a/poet/utils.go +++ b/poet/utils.go @@ -11,16 +11,6 @@ func appendModifiers(initial []Modifier, new []Modifier) []Modifier { return slices.Compact(initial) } -func maybeSetPackagePrivate(modifiers []Modifier) []Modifier { - if !slices.ContainsFunc(modifiers, func(m Modifier) bool { - return slices.Contains(accessModifiers, m) - }) { - modifiers = append(modifiers, ModifierPackagePrivate) - slices.Sort(modifiers) - } - return modifiers -} - func writeGenericParamList(ctx *Context, sb *strings.Builder, params []TypeName, includeTrailingSpace bool) { if len(params) == 0 { return diff --git a/tests/pom.xml b/tests/pom.xml index 0f3747d..8397f08 100644 --- a/tests/pom.xml +++ b/tests/pom.xml @@ -42,19 +42,19 @@ com.mysql mysql-connector-j - 9.3.0 + 9.4.0 org.junit.jupiter junit-jupiter - 5.13.3 + 5.13.4 test org.assertj assertj-core - 3.27.3 + 3.27.4 diff --git a/tests/sqlc.yaml b/tests/sqlc.yaml index e045e6c..a4c3a2f 100644 --- a/tests/sqlc.yaml +++ b/tests/sqlc.yaml @@ -1,9 +1,9 @@ -version: "2" +version: '2' plugins: - name: java wasm: url: file://sqlc-gen-java.wasm - sha256: 1e30725cfd253e18486ebff9ba7eea481bc2680d946ecae85d308de39485b921 + sha256: b1969953842ec52e35eec9242176a984675f8c6b60807c8183527cf908680d14 sql: - schema: src/main/resources/postgres/schema.sql queries: src/main/resources/postgres/queries.sql