diff --git a/README.md b/README.md
index d9905b0..c3488e2 100644
--- a/README.md
+++ b/README.md
@@ -62,12 +62,10 @@ You should ensure that the `sha256` value in your `sqlc.yaml` is correct for thi
## Planned Features
-- `MySQL` support
- `SQLite` support
- Improved parameter naming
**Tentative:**
- r2dbc support
-- Support for PostgreSQL enum types
- copyfrom support where possible [ref](https://www.baeldung.com/jdbc-batch-processing)
diff --git a/internal/codegen/enums.go b/internal/codegen/enums.go
new file mode 100644
index 0000000..9846212
--- /dev/null
+++ b/internal/codegen/enums.go
@@ -0,0 +1,75 @@
+package codegen
+
+import (
+ "fmt"
+ "github.com/iancoleman/strcase"
+ "github.com/tandemdude/sqlc-gen-java/internal/core"
+ "regexp"
+ "strings"
+ "unicode"
+ "unicode/utf8"
+)
+
+var javaInvalidIdentChars = regexp.MustCompile("[^$\\w]")
+
+func EnumClassName(qualifiedName, defaultSchema string) string {
+ return strcase.ToCamel(strings.TrimPrefix(qualifiedName, defaultSchema+"."))
+}
+
+func enumValueName(value string) string {
+ rep := strings.NewReplacer("-", "_", ":", "_", "/", "_", ".", "_")
+ name := rep.Replace(value)
+ name = strings.ToUpper(name)
+ name = javaInvalidIdentChars.ReplaceAllString(name, "")
+
+ r, _ := utf8.DecodeRuneInString(name)
+ if unicode.IsDigit(r) {
+ name = "_" + name
+ }
+ return name
+}
+
+func BuildEnumFile(engine string, conf core.Config, qualName string, enum core.Enum, defaultSchema string) (string, []byte, error) {
+ className := EnumClassName(qualName, defaultSchema)
+
+ sb := IndentStringBuilder{indentChar: conf.IndentChar, charsPerIndentLevel: conf.CharsPerIndentLevel}
+ sb.writeSqlcHeader()
+ sb.WriteString("\n")
+ sb.WriteString("package " + conf.Package + ".enums;\n")
+ sb.WriteString("\n")
+ sb.WriteString("import javax.annotation.processing.Generated;\n")
+ sb.WriteString("\n")
+ sb.WriteString("@Generated(\"io.github.tandemdude.sqlc-gen-java\")\n")
+ sb.WriteString("public enum " + className + " {\n")
+
+ if engine == "mysql" {
+ sb.WriteIndentedString(1, "BLANK(\"\"),\n")
+ }
+
+ // write other values
+ for i, value := range enum.Values {
+ name := enumValueName(value)
+ sb.WriteIndentedString(1, fmt.Sprintf("%s(\"%s\")", name, value))
+
+ if i < len(enum.Values)-1 {
+ sb.WriteString(",\n")
+ }
+ }
+ sb.WriteString(";\n\n")
+ sb.WriteIndentedString(1, "private final String value;\n\n")
+ sb.WriteIndentedString(1, className+"(final String value) {\n")
+ sb.WriteIndentedString(2, "this.value = value;\n")
+ sb.WriteIndentedString(1, "}\n\n")
+ sb.WriteIndentedString(1, "public String getValue() {\n")
+ sb.WriteIndentedString(2, "return this.value;")
+ sb.WriteIndentedString(1, "}\n\n")
+ sb.WriteIndentedString(1, "public static "+className+" fromValue(final String value) {\n")
+ sb.WriteIndentedString(2, "for (var v : "+className+".values()) {\n")
+ sb.WriteIndentedString(3, "if (v.value.equals(value)) return v;\n")
+ sb.WriteIndentedString(2, "}\n")
+ sb.WriteIndentedString(2, "throw new IllegalArgumentException(\"No enum constant with value \" + value);\n")
+ sb.WriteIndentedString(1, "}\n")
+ sb.WriteString("}\n")
+
+ return fmt.Sprintf("enums/%s.java", className), []byte(sb.String()), nil
+}
diff --git a/internal/codegen/models.go b/internal/codegen/models.go
index 2940020..3b586bd 100644
--- a/internal/codegen/models.go
+++ b/internal/codegen/models.go
@@ -28,9 +28,12 @@ func BuildModelFile(config core.Config, name string, model []core.QueryReturn) (
header.WriteString("\n")
header.WriteString("package " + config.Package + ".models;\n")
header.WriteString("\n")
+ header.WriteString("import javax.annotation.processing.Generated;\n")
+ header.WriteString("\n")
body := NewIndentStringBuilder(config.IndentChar, config.CharsPerIndentLevel)
body.WriteString("\n")
+ body.WriteString("@Generated(\"io.github.tandemdude.sqlc-gen-java\")\n")
body.WriteString("public record " + strcase.ToCamel(name) + "(\n")
for i, ret := range model {
imps, err := body.writeParameter(ret.JavaType, ret.Name, nonNullAnnotation, nullableAnnotation)
diff --git a/internal/codegen/queries.go b/internal/codegen/queries.go
index 02470f1..f3d68f8 100644
--- a/internal/codegen/queries.go
+++ b/internal/codegen/queries.go
@@ -119,14 +119,14 @@ func completeMethodBody(sb *IndentStringBuilder, q core.Query, embeddedModels co
}
}
-func BuildQueriesFile(config core.Config, queryFilename string, queries []core.Query, embeddedModels core.EmbeddedModels, nullableHelpers core.NullableHelpers) (string, []byte, error) {
+func BuildQueriesFile(engine string, config core.Config, queryFilename string, queries []core.Query, embeddedModels core.EmbeddedModels, nullableHelpers core.NullableHelpers) (string, []byte, error) {
className := strcase.ToCamel(strings.TrimSuffix(queryFilename, ".sql"))
className = strings.TrimSuffix(className, "Query")
className = strings.TrimSuffix(className, "Queries")
className += "Queries"
imports := make([]string, 0)
- imports = append(imports, "java.sql.SQLException", "java.sql.ResultSet", "java.util.Arrays")
+ imports = append(imports, "java.sql.SQLException", "java.sql.ResultSet", "java.util.Arrays", "javax.annotation.processing.Generated")
var nonNullAnnotation string
if config.NonNullAnnotation != "" {
@@ -148,6 +148,7 @@ func BuildQueriesFile(config core.Config, queryFilename string, queries []core.Q
body := NewIndentStringBuilder(config.IndentChar, config.CharsPerIndentLevel)
body.WriteString("\n")
// Add the class declaration and constructor
+ body.WriteString("@Generated(\"io.github.tandemdude.sqlc-gen-java\")\n")
body.WriteString("public class " + className + " {\n")
body.WriteIndentedString(1, "private final java.sql.Connection conn;\n\n")
body.WriteIndentedString(1, "public "+className+"(java.sql.Connection conn) {\n")
@@ -238,7 +239,11 @@ func BuildQueriesFile(config core.Config, queryFilename string, queries []core.Q
}
methodBody := NewIndentStringBuilder(config.IndentChar, config.CharsPerIndentLevel)
- methodBody.WriteIndentedString(2, "var stmt = conn.prepareStatement("+q.MethodName+");\n")
+ if q.Command == core.ExecResult {
+ methodBody.WriteIndentedString(2, "var stmt = conn.prepareStatement("+q.MethodName+", java.sql.Statement.RETURN_GENERATED_KEYS);\n")
+ } else {
+ methodBody.WriteIndentedString(2, "var stmt = conn.prepareStatement("+q.MethodName+");\n")
+ }
// write the method signature
body.WriteString("\n")
@@ -259,7 +264,7 @@ func BuildQueriesFile(config core.Config, queryFilename string, queries []core.Q
body.WriteString(",\n")
}
- methodBody.WriteIndentedString(2, arg.BindStmt()+"\n")
+ methodBody.WriteIndentedString(2, arg.BindStmt(engine)+"\n")
}
body.WriteString("\n")
body.WriteIndentedString(1, ") throws SQLException {\n")
diff --git a/internal/core/models.go b/internal/core/models.go
index 1aacbdc..c07bdc9 100644
--- a/internal/core/models.go
+++ b/internal/core/models.go
@@ -41,6 +41,7 @@ type JavaType struct {
Type string
IsList bool
IsNullable bool
+ IsEnum bool
}
type QueryArg struct {
@@ -68,7 +69,7 @@ var typeToJavaSqlTypeConst = map[string]string{
"Double": "DOUBLE",
}
-func (q QueryArg) BindStmt() string {
+func (q QueryArg) BindStmt(engine string) string {
typeOnly := q.JavaType.Type[strings.LastIndex(q.JavaType.Type, ".")+1:]
if q.JavaType.IsList {
@@ -94,6 +95,21 @@ func (q QueryArg) BindStmt() string {
return fmt.Sprintf("%s == null ? stmt.setNull(%d, java.sql.Types.%s) : %s", q.Name, q.Number, javaSqlType, rawSet)
}
+ if q.JavaType.IsEnum {
+ // postgres doesn't like it if you setString an enum directly unfortunately
+ if engine == "postgresql" {
+ if q.JavaType.IsNullable {
+ return fmt.Sprintf("stmt.setObject(%d, %s == null ? null : %s.getValue(), java.sql.Types.OTHER);", q.Number, q.Name, q.Name)
+ }
+ return fmt.Sprintf("stmt.setObject(%d, %s.getValue(), java.sql.Types.OTHER);", q.Number, q.Name)
+ }
+
+ if q.JavaType.IsNullable {
+ return fmt.Sprintf("stmt.setString(%d, %s == null ? null : %s.getValue());", q.Number, q.Name, q.Name)
+ }
+ return fmt.Sprintf("stmt.setString(%d, %s.getValue());", q.Number, q.Name)
+ }
+
return fmt.Sprintf("stmt.setObject(%d, %s);", q.Number, q.Name)
}
@@ -107,7 +123,6 @@ func (q QueryReturn) ResultStmt(number int) string {
typeOnly := q.JavaType.Type[strings.LastIndex(q.JavaType.Type, ".")+1:]
if q.JavaType.IsList {
- // TODO - check for nullable array support
if q.JavaType.IsNullable {
return fmt.Sprintf("getList(results, %d, %s[].class)", number, typeOnly)
}
@@ -127,6 +142,13 @@ func (q QueryReturn) ResultStmt(number int) string {
return fmt.Sprintf("results.get%s(%d)", typeOnly, number)
}
+ if q.JavaType.IsEnum {
+ if q.JavaType.IsNullable {
+ return fmt.Sprintf("Optional.ofNullable(results.getString(%d)).map(%s::fromValue).orElse(null)", number, typeOnly)
+ }
+ return fmt.Sprintf("%s.fromValue(results.getString(%d))", typeOnly, number)
+ }
+
return fmt.Sprintf("results.getObject(%d, %s.class)", number, typeOnly)
}
@@ -149,7 +171,15 @@ type NullableHelpers struct {
List bool
}
+type Enum struct {
+ Schema string
+ Name string
+ Values []string
+}
+
type (
Queries map[string][]Query
EmbeddedModels map[string][]QueryReturn
+ // Enums is a map of "schema_name.enum_name" to enum value.
+ Enums map[string]Enum
)
diff --git a/internal/gen.go b/internal/gen.go
index bb24b3e..57dc64b 100644
--- a/internal/gen.go
+++ b/internal/gen.go
@@ -25,8 +25,57 @@ var (
postgresPlaceholderRegexp = regexp.MustCompile(`\B\$\d+\b`)
)
-func fixQueryPlaceholders(engine, query string) (string, error) {
- if engine != "postgresql" {
+type JavaGenerator struct {
+ req *plugin.GenerateRequest
+ conf core.Config
+
+ queries core.Queries
+ models core.EmbeddedModels
+
+ enums core.Enums
+ usedEnums []string
+
+ typeConversionFunc sqltypes.TypeConversionFunc
+ nullableHelpers core.NullableHelpers
+}
+
+func NewJavaGenerator(req *plugin.GenerateRequest) (*JavaGenerator, error) {
+ conf := core.Config{
+ IndentChar: defaultIndentChar,
+ CharsPerIndentLevel: defaultCharsPerIndentLevel,
+ NullableAnnotation: "org.jspecify.annotations.Nullable",
+ NonNullAnnotation: "org.jspecify.annotations.NonNull",
+ }
+ if len(req.PluginOptions) > 0 {
+ if err := json.Unmarshal(req.PluginOptions, &conf); err != nil {
+ return nil, err
+ }
+ }
+
+ var typeConversionFunc sqltypes.TypeConversionFunc
+ switch req.Settings.Engine {
+ case "postgresql":
+ typeConversionFunc = sqltypes.PostgresTypeToJavaType
+ case "mysql":
+ typeConversionFunc = sqltypes.MysqlTypeToJavaType
+ default:
+ return nil, fmt.Errorf("engine %q is not supported", req.Settings.Engine)
+ }
+
+ return &JavaGenerator{
+ req: req,
+ conf: conf,
+ queries: make(core.Queries),
+ models: make(core.EmbeddedModels),
+ enums: make(core.Enums),
+ usedEnums: make([]string, 0),
+ typeConversionFunc: typeConversionFunc,
+ nullableHelpers: core.NullableHelpers{},
+ }, nil
+}
+
+func (gen *JavaGenerator) fixQueryPlaceholders(query string) (string, error) {
+ if gen.req.Settings.Engine != "postgresql" {
return query, nil
}
@@ -45,10 +94,23 @@ func fixQueryPlaceholders(engine, query string) (string, error) {
return newQuery, nil
}
-func parseQueryReturn(tcf sqltypes.TypeConversionFunc, nullableHelpers *core.NullableHelpers, col *plugin.Column) (*core.QueryReturn, error) {
- strJavaType, err := tcf(col.Type)
+func (gen *JavaGenerator) parseQueryReturn(col *plugin.Column) (*core.QueryReturn, error) {
+ isEnum := false
+ strJavaType, err := gen.typeConversionFunc(col.Type)
if err != nil {
- return nil, err
+ schema := col.Table.Schema
+ if schema == "" {
+ schema = gen.req.Catalog.DefaultSchema
+ }
+
+ enumQualifiedName := fmt.Sprintf("%s.%s", schema, col.Type.Name)
+ if _, ok := gen.enums[enumQualifiedName]; !ok {
+ return nil, err
+ }
+
+ gen.usedEnums = append(gen.usedEnums, enumQualifiedName)
+ strJavaType = gen.conf.Package + ".enums." + codegen.EnumClassName(enumQualifiedName, gen.req.Catalog.DefaultSchema)
+ isEnum = true
}
if col.ArrayDims > 1 {
@@ -60,23 +122,24 @@ func parseQueryReturn(tcf sqltypes.TypeConversionFunc, nullableHelpers *core.Nul
Type: strJavaType,
IsList: col.IsArray,
IsNullable: !col.NotNull,
+ IsEnum: isEnum,
}
if javaType.IsNullable {
if javaType.IsList {
- nullableHelpers.List = true
+ gen.nullableHelpers.List = true
} else {
switch strJavaType {
case "Integer":
- nullableHelpers.Int = true
+ gen.nullableHelpers.Int = true
case "Long":
- nullableHelpers.Long = true
+ gen.nullableHelpers.Long = true
case "Float":
- nullableHelpers.Float = true
+ gen.nullableHelpers.Float = true
case "Double":
- nullableHelpers.Double = true
+ gen.nullableHelpers.Double = true
case "Boolean":
- nullableHelpers.Boolean = true
+ gen.nullableHelpers.Boolean = true
}
}
}
@@ -87,39 +150,22 @@ func parseQueryReturn(tcf sqltypes.TypeConversionFunc, nullableHelpers *core.Nul
}, nil
}
-func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.GenerateResponse, error) {
- conf := core.Config{
- IndentChar: defaultIndentChar,
- CharsPerIndentLevel: defaultCharsPerIndentLevel,
- NullableAnnotation: "org.jspecify.annotations.Nullable",
- NonNullAnnotation: "org.jspecify.annotations.NonNull",
- }
- if len(req.PluginOptions) > 0 {
- if err := json.Unmarshal(req.PluginOptions, &conf); err != nil {
- return nil, err
+func (gen *JavaGenerator) Run() (*plugin.GenerateResponse, error) {
+ // parse out the enums from the generate request
+ for _, schema := range gen.req.Catalog.Schemas {
+ for _, enum := range schema.Enums {
+ gen.enums[fmt.Sprintf("%s.%s", schema.Name, enum.Name)] = core.Enum{
+ Schema: schema.Name,
+ Name: enum.Name,
+ Values: enum.Vals,
+ }
}
}
- if conf.Package == "" {
- return nil, fmt.Errorf("'package' is a required configuration option")
- }
-
- var typeConversionFunc sqltypes.TypeConversionFunc
- switch req.Settings.Engine {
- case "postgresql":
- typeConversionFunc = sqltypes.PostgresTypeToJavaType
- default:
- return nil, fmt.Errorf("engine %q is not supported", req.Settings.Engine)
- }
-
- var queries core.Queries = make(map[string][]core.Query)
- var embeddedModels core.EmbeddedModels = make(map[string][]core.QueryReturn)
- nullableHelpers := core.NullableHelpers{}
-
// parse the incoming generate request into our Queries type
- for _, query := range req.Queries {
- if _, ok := queries[query.Filename]; !ok {
- queries[query.Filename] = make([]core.Query, 0)
+ for _, query := range gen.req.Queries {
+ if _, ok := gen.queries[query.Filename]; !ok {
+ gen.queries[query.Filename] = make([]core.Query, 0)
}
command, err := core.QueryCommandFor(query.Cmd)
@@ -130,9 +176,23 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat
// TODO - enum types? other specialness?
args := make([]core.QueryArg, 0)
for _, arg := range query.Params {
- javaType, err := typeConversionFunc(arg.Column.Type)
+ isEnum := false
+ javaType, err := gen.typeConversionFunc(arg.Column.Type)
if err != nil {
- return nil, err
+ // check if this is an enum type
+ schema := arg.Column.Table.Schema
+ if schema == "" {
+ schema = gen.req.Catalog.DefaultSchema
+ }
+
+ enumQualifiedName := fmt.Sprintf("%s.%s", schema, arg.Column.Type.Name)
+ if _, ok := gen.enums[enumQualifiedName]; !ok {
+ return nil, err
+ }
+
+ gen.usedEnums = append(gen.usedEnums, enumQualifiedName)
+ javaType = gen.conf.Package + ".enums." + codegen.EnumClassName(enumQualifiedName, gen.req.Catalog.DefaultSchema)
+ isEnum = true
}
if arg.Column.ArrayDims > 1 {
@@ -145,8 +205,9 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat
JavaType: core.JavaType{
SqlType: sdk.DataType(arg.Column.Type),
Type: javaType,
- IsList: arg.Column.IsArray, // TODO check this will always be present
+ IsList: arg.Column.IsArray,
IsNullable: !arg.Column.NotNull,
+ IsEnum: isEnum,
},
})
}
@@ -156,7 +217,7 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat
for _, ret := range query.Columns {
if ret.EmbedTable == nil {
// normal types
- qr, err := parseQueryReturn(typeConversionFunc, &nullableHelpers, ret)
+ qr, err := gen.parseQueryReturn(ret)
if err != nil {
return nil, errors.Join(errors.New("failed to parse query return column"), err)
}
@@ -169,12 +230,12 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat
var table *plugin.Table
// find the catalog entry for the embedded table
- schema := req.Catalog.DefaultSchema
+ schema := gen.req.Catalog.DefaultSchema
if ret.EmbedTable.Schema != "" {
schema = ret.EmbedTable.Schema
}
- for _, s := range req.Catalog.Schemas {
+ for _, s := range gen.req.Catalog.Schemas {
if s.Name != schema {
continue
}
@@ -192,15 +253,15 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat
// TODO - fix type-writer to only exclude items that aren't part of the package name
modelName := strcase.ToCamel(table.Rel.Name)
- if !conf.EmitExactTableNames {
- modelName = strcase.ToCamel(inflection.Singular(table.Rel.Name, conf.InflectionExcludeTableNames))
+ if !gen.conf.EmitExactTableNames {
+ modelName = strcase.ToCamel(inflection.Singular(table.Rel.Name, gen.conf.InflectionExcludeTableNames))
}
// check if we already have an entry for this model
- if _, ok := embeddedModels[modelName]; !ok {
+ if _, ok := gen.models[modelName]; !ok {
var modelParams []core.QueryReturn
for _, c := range table.Columns {
- qr, err := parseQueryReturn(typeConversionFunc, &nullableHelpers, c)
+ qr, err := gen.parseQueryReturn(c)
if err != nil {
return nil, errors.Join(errors.New("failed to parse query return column"), err)
}
@@ -208,7 +269,7 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat
modelParams = append(modelParams, *qr)
}
- embeddedModels[modelName] = modelParams
+ gen.models[modelName] = modelParams
}
returns = append(returns, core.QueryReturn{
@@ -216,7 +277,7 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat
JavaType: core.JavaType{
SqlType: "",
// we don't need to specify package here - models file will be generated in the same location as the queries file
- Type: conf.Package + ".models." + modelName,
+ Type: gen.conf.Package + ".models." + modelName,
IsList: false, // TODO - check: this *should* be impossible
IsNullable: false, // TODO - check: empty record should be output instead
},
@@ -225,12 +286,12 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat
}
// TODO - look into fixing ? operator for postgresql JSONB operations maybe
- newQueryText, err := fixQueryPlaceholders(req.Settings.Engine, query.Text)
+ newQueryText, err := gen.fixQueryPlaceholders(query.Text)
if err != nil {
return nil, err
}
- queries[query.Filename] = append(queries[query.Filename], core.Query{
+ gen.queries[query.Filename] = append(gen.queries[query.Filename], core.Query{
RawCommand: query.Cmd,
Command: command,
Text: newQueryText,
@@ -243,12 +304,12 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat
}
outputFiles := make([]*plugin.File, 0)
- // order the queries for each file alphabetically
- for file := range queries {
- slices.SortFunc(queries[file], func(a, b core.Query) int { return strings.Compare(a.MethodName, b.MethodName) })
+ for file := range gen.queries {
+ // order the queries for each file alphabetically
+ slices.SortFunc(gen.queries[file], func(a, b core.Query) int { return strings.Compare(a.MethodName, b.MethodName) })
// build the queries file contents
- fileName, fileContents, err := codegen.BuildQueriesFile(conf, file, queries[file], embeddedModels, nullableHelpers)
+ fileName, fileContents, err := codegen.BuildQueriesFile(gen.req.Settings.Engine, gen.conf, file, gen.queries[file], gen.models, gen.nullableHelpers)
if err != nil {
return nil, err
}
@@ -258,8 +319,27 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat
})
}
- for modelName, model := range embeddedModels {
- fileName, fileContents, err := codegen.BuildModelFile(conf, modelName, model)
+ for modelName, model := range gen.models {
+ fileName, fileContents, err := codegen.BuildModelFile(gen.conf, modelName, model)
+ if err != nil {
+ return nil, err
+ }
+ outputFiles = append(outputFiles, &plugin.File{
+ Name: fileName,
+ Contents: fileContents,
+ })
+ }
+
+ // remove duplicate enum entries
+ slices.Sort(gen.usedEnums)
+ slices.Compact(gen.usedEnums)
+ for _, qualName := range gen.usedEnums {
+ if qualName == "" {
+ continue
+ }
+
+ enum := gen.enums[qualName]
+ fileName, fileContents, err := codegen.BuildEnumFile(gen.req.Settings.Engine, gen.conf, qualName, enum, gen.req.Catalog.DefaultSchema)
if err != nil {
return nil, err
}
@@ -271,3 +351,13 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat
return &plugin.GenerateResponse{Files: outputFiles}, nil
}
+
+// TODO - check if the context is actually important for anything
+func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.GenerateResponse, error) {
+ jg, err := NewJavaGenerator(req)
+ if err != nil {
+ return nil, err
+ }
+
+ return jg.Run()
+}
diff --git a/internal/sqltypes/mysql.go b/internal/sqltypes/mysql.go
new file mode 100644
index 0000000..5b9b9a8
--- /dev/null
+++ b/internal/sqltypes/mysql.go
@@ -0,0 +1,40 @@
+package sqltypes
+
+import (
+ "fmt"
+ "github.com/sqlc-dev/plugin-sdk-go/plugin"
+ "github.com/sqlc-dev/plugin-sdk-go/sdk"
+)
+
+func MysqlTypeToJavaType(identifier *plugin.Identifier) (string, error) {
+ colType := sdk.DataType(identifier)
+
+ switch colType {
+ case "varchar", "text", "char", "tinytext", "mediumtext", "longtext":
+ return "String", nil
+ case "int", "integer", "smallint", "mediumint", "year":
+ return "Integer", nil
+ case "bigint":
+ return "Long", nil
+ case "blob", "binary", "varbinary", "tinyblob", "mediumblob", "longblob":
+ return "byte[]", nil
+ case "double", "double precision", "real":
+ return "Double", nil
+ // TODO - why does gen-kotlin use string for this? look into a better solution
+ case "decimal", "dec", "fixed":
+ return "String", nil
+ case "date":
+ return "java.time.LocalDate", nil
+ case "datetime", "time":
+ return "java.time.LocalDateTime", nil
+ // TODO - instant support - look into option for this in pgsql as well
+ case "timestamp":
+ return "java.time.OffsetDateTime", nil
+ case "boolean", "bool", "tinyint":
+ return "Boolean", nil
+ case "json":
+ return "String", nil
+ default:
+ return "", fmt.Errorf("datatype '%s' not currently supported", colType)
+ }
+}
diff --git a/internal/sqltypes/postgresql.go b/internal/sqltypes/postgresql.go
index c44d32f..042a5c8 100644
--- a/internal/sqltypes/postgresql.go
+++ b/internal/sqltypes/postgresql.go
@@ -24,9 +24,6 @@ func PostgresTypeToJavaType(identifier *plugin.Identifier) (string, error) {
return "java.math.BigDecimal", nil
case "bool", "pg_catalog.bool":
return "Boolean", nil
- // TODO - figure out if this can be supported properly
- case "jsonb":
- return "String", nil
case "bytea", "blob", "pg_catalog.bytea":
return "byte[]", nil
case "date":
@@ -41,10 +38,11 @@ func PostgresTypeToJavaType(identifier *plugin.Identifier) (string, error) {
return "String", nil
case "uuid":
return "java.util.UUID", nil
- case "inet", "void", "any":
- return "", fmt.Errorf("datatype '%s' not currently supported", colType)
+ // TODO - figure out if these can be supported properly
+ case "jsonb", "inet":
+ return "String", nil
default:
- // TODO - deal with enums somehow
+ // void, any
return "", fmt.Errorf("datatype '%s' not currently supported", colType)
}
}
diff --git a/tests/pom.xml b/tests/pom.xml
index 2c74b14..4b1aff8 100644
--- a/tests/pom.xml
+++ b/tests/pom.xml
@@ -27,16 +27,23 @@
+
org.jspecify
jspecify
1.0.0
+
org.postgresql
postgresql
42.7.5
+
+ com.mysql
+ mysql-connector-j
+ 9.2.0
+
org.junit.jupiter
@@ -49,6 +56,7 @@
assertj-core
3.27.3
+
org.testcontainers
testcontainers
@@ -64,6 +72,11 @@
postgresql
test
+
+ org.testcontainers
+ mysql
+ test
+
diff --git a/tests/sqlc.yaml b/tests/sqlc.yaml
index 1431440..e045e6c 100644
--- a/tests/sqlc.yaml
+++ b/tests/sqlc.yaml
@@ -3,7 +3,7 @@ plugins:
- name: java
wasm:
url: file://sqlc-gen-java.wasm
- sha256: a2211754142fd64604a7b5b2cbec985f685830ca07db860974a7996f6c06fb5b
+ sha256: 1e30725cfd253e18486ebff9ba7eea481bc2680d946ecae85d308de39485b921
sql:
- schema: src/main/resources/postgres/schema.sql
queries: src/main/resources/postgres/queries.sql
@@ -13,3 +13,11 @@ sql:
plugin: java
options:
package: io.github.tandemdude.sgj.postgres
+ - schema: src/main/resources/mysql/schema.sql
+ queries: src/main/resources/mysql/queries.sql
+ engine: mysql
+ codegen:
+ - out: src/main/java/io/github/tandemdude/sgj/mysql
+ plugin: java
+ options:
+ package: io.github.tandemdude.sgj.mysql
diff --git a/tests/src/main/resources/mysql/queries.sql b/tests/src/main/resources/mysql/queries.sql
new file mode 100644
index 0000000..b235fa1
--- /dev/null
+++ b/tests/src/main/resources/mysql/queries.sql
@@ -0,0 +1,60 @@
+-- name: GetAuthor :one
+SELECT * FROM authors
+WHERE author_id = ?;
+
+-- name: GetBook :one
+SELECT * FROM books
+WHERE book_id = ?;
+
+-- name: DeleteBook :exec
+DELETE FROM books
+WHERE book_id = ?;
+
+-- name: BooksByTitleYear :many
+SELECT * FROM books
+WHERE title = ? AND yr = ?;
+
+-- name: BooksByTags :many
+SELECT
+ book_id,
+ title,
+ name,
+ isbn,
+ tags
+FROM books
+ LEFT JOIN authors ON books.author_id = authors.author_id
+WHERE tags = ?;
+
+-- name: CreateAuthor :execresult
+INSERT INTO authors (name) VALUES (?);
+
+-- name: CreateBook :execresult
+INSERT INTO books (
+ author_id,
+ isbn,
+ book_type,
+ title,
+ yr,
+ available,
+ tags
+) VALUES (?, ?, ?, ?, ?, ?, ?);
+
+-- name: UpdateBook :exec
+UPDATE books
+SET title = ?, tags = ?
+WHERE book_id = ?;
+
+-- name: UpdateBookISBN :exec
+UPDATE books
+SET title = ?, tags = ?, isbn = ?
+WHERE book_id = ?;
+
+-- name: DeleteAuthorBeforeYear :exec
+DELETE FROM books
+WHERE yr < ? AND author_id = ?;
+
+-- name: CreateEnumRow :execresult
+INSERT INTO nullable_enum_test (enum_field) VALUES (?);
+
+-- name: GetEnumRow :one
+SELECT * FROM nullable_enum_test WHERE t_id = ?;
\ No newline at end of file
diff --git a/tests/src/main/resources/mysql/schema.sql b/tests/src/main/resources/mysql/schema.sql
new file mode 100644
index 0000000..ed6b2aa
--- /dev/null
+++ b/tests/src/main/resources/mysql/schema.sql
@@ -0,0 +1,26 @@
+-- taken from sqlc-gen-kotlin, booktest --
+CREATE TABLE authors (
+ author_id integer NOT NULL AUTO_INCREMENT PRIMARY KEY,
+ name text NOT NULL
+) ENGINE=InnoDB;
+
+CREATE INDEX authors_name_idx ON authors(name(255));
+
+CREATE TABLE books (
+ book_id integer NOT NULL AUTO_INCREMENT PRIMARY KEY,
+ author_id integer NOT NULL,
+ isbn varchar(255) NOT NULL DEFAULT '' UNIQUE,
+ book_type ENUM('FICTION', 'NONFICTION') NOT NULL DEFAULT 'FICTION',
+ title text NOT NULL,
+ yr integer NOT NULL DEFAULT 2000,
+ available datetime NOT NULL DEFAULT NOW(),
+ tags text NOT NULL
+) ENGINE=InnoDB;
+
+CREATE INDEX books_title_idx ON books(title(255), yr);
+-- end --
+
+CREATE TABLE nullable_enum_test (
+ t_id integer NOT NULL AUTO_INCREMENT PRIMARY KEY,
+ enum_field ENUM('foo', 'bar') DEFAULT NULL
+);
diff --git a/tests/src/main/resources/postgres/queries.sql b/tests/src/main/resources/postgres/queries.sql
index 6fa0a75..36dc980 100644
--- a/tests/src/main/resources/postgres/queries.sql
+++ b/tests/src/main/resources/postgres/queries.sql
@@ -49,3 +49,9 @@ RETURNING row_id;
-- name: GetBytes :one
SELECT * FROM bytes WHERE row_id = $1;
+
+-- name: CreatePerson :exec
+INSERT INTO person(name, current_mood, next_mood) VALUES ($1, $2, $3);
+
+-- name: GetPerson :one
+SELECT * FROM person WHERE name = $1;
diff --git a/tests/src/main/resources/postgres/schema.sql b/tests/src/main/resources/postgres/schema.sql
index c7c0d4d..663f1d4 100644
--- a/tests/src/main/resources/postgres/schema.sql
+++ b/tests/src/main/resources/postgres/schema.sql
@@ -38,3 +38,10 @@ CREATE TABLE bytes (
contents BYTEA NOT NULL,
hash BYTEA DEFAULT NULL
);
+
+CREATE TYPE mood AS ENUM ('sad', 'happy', 'ok');
+CREATE TABLE person (
+ name TEXT PRIMARY KEY,
+ current_mood mood NOT NULL,
+ next_mood mood DEFAULT NULL
+);
diff --git a/tests/src/test/java/io/github/tandemdude/sgj/mysql/TestQueries.java b/tests/src/test/java/io/github/tandemdude/sgj/mysql/TestQueries.java
new file mode 100644
index 0000000..c629bb7
--- /dev/null
+++ b/tests/src/test/java/io/github/tandemdude/sgj/mysql/TestQueries.java
@@ -0,0 +1,63 @@
+package io.github.tandemdude.sgj.mysql;
+
+import io.github.tandemdude.sgj.mysql.enums.BooksBookType;
+import io.github.tandemdude.sgj.mysql.enums.NullableEnumTestEnumField;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+import org.testcontainers.containers.MySQLContainer;
+import org.testcontainers.junit.jupiter.Container;
+import org.testcontainers.junit.jupiter.Testcontainers;
+
+import java.sql.Connection;
+import java.sql.DriverManager;
+import java.sql.SQLException;
+import java.time.LocalDateTime;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+@Testcontainers
+public class TestQueries {
+ @Container
+ private final MySQLContainer> mysql = new MySQLContainer<>("mysql:latest")
+ .withInitScript("mysql/schema.sql");
+
+ Connection getConn() throws SQLException {
+ var conn = DriverManager.getConnection(mysql.getJdbcUrl(), mysql.getUsername(), mysql.getPassword());
+ conn.setAutoCommit(true);
+ return conn;
+ }
+
+ @Test
+ @DisplayName("enum types can be read and written")
+ void enumTypesCanBeReadAndWritten() throws Exception {
+ try (var conn = getConn()) {
+ var q = new Queries(conn);
+
+ var authorId = q.createAuthor("foo");
+ var bookId = q.createBook((int) authorId, "foo", BooksBookType.FICTION, "bar", 2000, LocalDateTime.now(), "baz");
+
+ var foundRow = q.getBook((int) bookId);
+ assertThat(foundRow).isPresent();
+ var found = foundRow.get();
+ assertThat(found.bookType()).isEqualTo(BooksBookType.FICTION);
+ }
+ }
+
+ @Test
+ @DisplayName("nullable enum types can be read and written")
+ void nullableEnumTypesCanBeReadAndWritten() throws Exception {
+ try (var conn = getConn()) {
+ var q = new Queries(conn);
+
+ var id = q.createEnumRow(NullableEnumTestEnumField.BAR);
+ var foundRow = q.getEnumRow((int) id);
+ assertThat(foundRow).isPresent();
+ assertThat(foundRow.get().enumField()).isEqualTo(NullableEnumTestEnumField.BAR);
+
+ var id2 = q.createEnumRow(null);
+ var foundRow2 = q.getEnumRow((int) id2);
+ assertThat(foundRow2).isPresent();
+ assertThat(foundRow2.get().enumField()).isNull();
+ }
+ }
+}
diff --git a/tests/src/test/java/io/github/tandemdude/sgj/postgres/TestQueries.java b/tests/src/test/java/io/github/tandemdude/sgj/postgres/TestQueries.java
index 88d7401..07d4a1c 100644
--- a/tests/src/test/java/io/github/tandemdude/sgj/postgres/TestQueries.java
+++ b/tests/src/test/java/io/github/tandemdude/sgj/postgres/TestQueries.java
@@ -1,6 +1,6 @@
package io.github.tandemdude.sgj.postgres;
-import org.jspecify.annotations.NonNull;
+import io.github.tandemdude.sgj.postgres.enums.Mood;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.DisplayName;
import org.testcontainers.containers.PostgreSQLContainer;
@@ -32,7 +32,7 @@ Connection getConn() throws SQLException {
@Test
@DisplayName("GetUser returns empty optional when no records found")
- public void getUserReturnsEmptyOptionalNoRecordsFound() throws Exception {
+ void getUserReturnsEmptyOptionalNoRecordsFound() throws Exception {
try (var conn = getConn()) {
var q = new Queries(conn);
@@ -42,7 +42,7 @@ public void getUserReturnsEmptyOptionalNoRecordsFound() throws Exception {
@Test
@DisplayName("GetUser returns populated optional when record found")
- public void getUserReturnsPopulatedOptionalRecordFound() throws Exception {
+ void getUserReturnsPopulatedOptionalRecordFound() throws Exception {
try (var conn = getConn()) {
var q = new Queries(conn);
@@ -59,7 +59,7 @@ public void getUserReturnsPopulatedOptionalRecordFound() throws Exception {
@Test
@DisplayName("ListUsers returns empty list when no records found")
- public void listUsersReturnsEmptyListNoRecordsFound() throws Exception {
+ void listUsersReturnsEmptyListNoRecordsFound() throws Exception {
try (var conn = getConn()) {
var q = new Queries(conn);
@@ -69,7 +69,7 @@ public void listUsersReturnsEmptyListNoRecordsFound() throws Exception {
@Test
@DisplayName("ListUsers returns populated list when records found")
- public void listUsersReturnsPopulatedListRecordsFound() throws Exception {
+ void listUsersReturnsPopulatedListRecordsFound() throws Exception {
try (var conn = getConn()) {
var q = new Queries(conn);
@@ -84,7 +84,7 @@ public void listUsersReturnsPopulatedListRecordsFound() throws Exception {
@Test
@DisplayName("GetUserDup throws error when multiple records returned")
- public void getUserDupReturnsErrorWhenMultipleRecordsReturned() throws Exception {
+ void getUserDupReturnsErrorWhenMultipleRecordsReturned() throws Exception {
try (var conn = getConn()) {
var q = new Queries(conn);
@@ -99,7 +99,7 @@ public void getUserDupReturnsErrorWhenMultipleRecordsReturned() throws Exception
@Test
@DisplayName("CreateMessage processes input list correctly")
- public void createMessageProcessesInputListCorrectly() throws Exception {
+ void createMessageProcessesInputListCorrectly() throws Exception {
try (var conn = getConn()) {
var q = new Queries(conn);
@@ -114,7 +114,7 @@ public void createMessageProcessesInputListCorrectly() throws Exception {
@Test
@DisplayName("GetMessage works when attachments is null")
- public void getMessageWorksWhenAttachmentsIsNull() throws Exception {
+ void getMessageWorksWhenAttachmentsIsNull() throws Exception {
try (var conn = getConn()) {
var q = new Queries(conn);
@@ -129,7 +129,7 @@ public void getMessageWorksWhenAttachmentsIsNull() throws Exception {
@Test
@DisplayName("GetUserAndToken returns embedded objects")
- public void getUserAndTokenReturnsEmbeddedObjects() throws Exception {
+ void getUserAndTokenReturnsEmbeddedObjects() throws Exception {
try (var conn = getConn()) {
var q = new Queries(conn);
@@ -146,7 +146,7 @@ public void getUserAndTokenReturnsEmbeddedObjects() throws Exception {
@Test
@DisplayName("GetBytes returns same data as during creation")
- public void getBytesReturnsSameDataAsDuringCreation() throws Exception {
+ void getBytesReturnsSameDataAsDuringCreation() throws Exception {
try (var conn = getConn()) {
var q = new Queries(conn);
@@ -169,4 +169,24 @@ public void getBytesReturnsSameDataAsDuringCreation() throws Exception {
assertThat(found2.get().hash()).isEqualTo(s2);
}
}
+
+ @Test
+ @DisplayName("nullable and non-nullable enums work correctly")
+ void nullableAndNonNullableEnumsWorkCorrectly() throws Exception {
+ try (var conn = getConn()) {
+ var q = new Queries(conn);
+
+ q.createPerson("foo", Mood.HAPPY, null);
+ var p1 = q.getPerson("foo");
+ assertThat(p1).isPresent();
+ assertThat(p1.get().currentMood()).isEqualTo(Mood.HAPPY);
+ assertThat(p1.get().nextMood()).isNull();
+
+ q.createPerson("bar", Mood.HAPPY, Mood.OK);
+ var p2 = q.getPerson("bar");
+ assertThat(p2).isPresent();
+ assertThat(p2.get().currentMood()).isEqualTo(Mood.HAPPY);
+ assertThat(p2.get().nextMood()).isEqualTo(Mood.OK);
+ }
+ }
}