diff --git a/README.md b/README.md index d9905b0..c3488e2 100644 --- a/README.md +++ b/README.md @@ -62,12 +62,10 @@ You should ensure that the `sha256` value in your `sqlc.yaml` is correct for thi ## Planned Features -- `MySQL` support - `SQLite` support - Improved parameter naming **Tentative:** - r2dbc support -- Support for PostgreSQL enum types - copyfrom support where possible [ref](https://www.baeldung.com/jdbc-batch-processing) diff --git a/internal/codegen/enums.go b/internal/codegen/enums.go new file mode 100644 index 0000000..9846212 --- /dev/null +++ b/internal/codegen/enums.go @@ -0,0 +1,75 @@ +package codegen + +import ( + "fmt" + "github.com/iancoleman/strcase" + "github.com/tandemdude/sqlc-gen-java/internal/core" + "regexp" + "strings" + "unicode" + "unicode/utf8" +) + +var javaInvalidIdentChars = regexp.MustCompile("[^$\\w]") + +func EnumClassName(qualifiedName, defaultSchema string) string { + return strcase.ToCamel(strings.TrimPrefix(qualifiedName, defaultSchema+".")) +} + +func enumValueName(value string) string { + rep := strings.NewReplacer("-", "_", ":", "_", "/", "_", ".", "_") + name := rep.Replace(value) + name = strings.ToUpper(name) + name = javaInvalidIdentChars.ReplaceAllString(name, "") + + r, _ := utf8.DecodeRuneInString(name) + if unicode.IsDigit(r) { + name = "_" + name + } + return name +} + +func BuildEnumFile(engine string, conf core.Config, qualName string, enum core.Enum, defaultSchema string) (string, []byte, error) { + className := EnumClassName(qualName, defaultSchema) + + sb := IndentStringBuilder{indentChar: conf.IndentChar, charsPerIndentLevel: conf.CharsPerIndentLevel} + sb.writeSqlcHeader() + sb.WriteString("\n") + sb.WriteString("package " + conf.Package + ".enums;\n") + sb.WriteString("\n") + sb.WriteString("import javax.annotation.processing.Generated;\n") + sb.WriteString("\n") + sb.WriteString("@Generated(\"io.github.tandemdude.sqlc-gen-java\")\n") + sb.WriteString("public enum " + className + " {\n") + + if engine == "mysql" { + sb.WriteIndentedString(1, "BLANK(\"\"),\n") + } + + // write other values + for i, value := range enum.Values { + name := enumValueName(value) + sb.WriteIndentedString(1, fmt.Sprintf("%s(\"%s\")", name, value)) + + if i < len(enum.Values)-1 { + sb.WriteString(",\n") + } + } + sb.WriteString(";\n\n") + sb.WriteIndentedString(1, "private final String value;\n\n") + sb.WriteIndentedString(1, className+"(final String value) {\n") + sb.WriteIndentedString(2, "this.value = value;\n") + sb.WriteIndentedString(1, "}\n\n") + sb.WriteIndentedString(1, "public String getValue() {\n") + sb.WriteIndentedString(2, "return this.value;") + sb.WriteIndentedString(1, "}\n\n") + sb.WriteIndentedString(1, "public static "+className+" fromValue(final String value) {\n") + sb.WriteIndentedString(2, "for (var v : "+className+".values()) {\n") + sb.WriteIndentedString(3, "if (v.value.equals(value)) return v;\n") + sb.WriteIndentedString(2, "}\n") + sb.WriteIndentedString(2, "throw new IllegalArgumentException(\"No enum constant with value \" + value);\n") + sb.WriteIndentedString(1, "}\n") + sb.WriteString("}\n") + + return fmt.Sprintf("enums/%s.java", className), []byte(sb.String()), nil +} diff --git a/internal/codegen/models.go b/internal/codegen/models.go index 2940020..3b586bd 100644 --- a/internal/codegen/models.go +++ b/internal/codegen/models.go @@ -28,9 +28,12 @@ func BuildModelFile(config core.Config, name string, model []core.QueryReturn) ( header.WriteString("\n") header.WriteString("package " + config.Package + ".models;\n") header.WriteString("\n") + header.WriteString("import javax.annotation.processing.Generated;\n") + header.WriteString("\n") body := NewIndentStringBuilder(config.IndentChar, config.CharsPerIndentLevel) body.WriteString("\n") + body.WriteString("@Generated(\"io.github.tandemdude.sqlc-gen-java\")\n") body.WriteString("public record " + strcase.ToCamel(name) + "(\n") for i, ret := range model { imps, err := body.writeParameter(ret.JavaType, ret.Name, nonNullAnnotation, nullableAnnotation) diff --git a/internal/codegen/queries.go b/internal/codegen/queries.go index 02470f1..f3d68f8 100644 --- a/internal/codegen/queries.go +++ b/internal/codegen/queries.go @@ -119,14 +119,14 @@ func completeMethodBody(sb *IndentStringBuilder, q core.Query, embeddedModels co } } -func BuildQueriesFile(config core.Config, queryFilename string, queries []core.Query, embeddedModels core.EmbeddedModels, nullableHelpers core.NullableHelpers) (string, []byte, error) { +func BuildQueriesFile(engine string, config core.Config, queryFilename string, queries []core.Query, embeddedModels core.EmbeddedModels, nullableHelpers core.NullableHelpers) (string, []byte, error) { className := strcase.ToCamel(strings.TrimSuffix(queryFilename, ".sql")) className = strings.TrimSuffix(className, "Query") className = strings.TrimSuffix(className, "Queries") className += "Queries" imports := make([]string, 0) - imports = append(imports, "java.sql.SQLException", "java.sql.ResultSet", "java.util.Arrays") + imports = append(imports, "java.sql.SQLException", "java.sql.ResultSet", "java.util.Arrays", "javax.annotation.processing.Generated") var nonNullAnnotation string if config.NonNullAnnotation != "" { @@ -148,6 +148,7 @@ func BuildQueriesFile(config core.Config, queryFilename string, queries []core.Q body := NewIndentStringBuilder(config.IndentChar, config.CharsPerIndentLevel) body.WriteString("\n") // Add the class declaration and constructor + body.WriteString("@Generated(\"io.github.tandemdude.sqlc-gen-java\")\n") body.WriteString("public class " + className + " {\n") body.WriteIndentedString(1, "private final java.sql.Connection conn;\n\n") body.WriteIndentedString(1, "public "+className+"(java.sql.Connection conn) {\n") @@ -238,7 +239,11 @@ func BuildQueriesFile(config core.Config, queryFilename string, queries []core.Q } methodBody := NewIndentStringBuilder(config.IndentChar, config.CharsPerIndentLevel) - methodBody.WriteIndentedString(2, "var stmt = conn.prepareStatement("+q.MethodName+");\n") + if q.Command == core.ExecResult { + methodBody.WriteIndentedString(2, "var stmt = conn.prepareStatement("+q.MethodName+", java.sql.Statement.RETURN_GENERATED_KEYS);\n") + } else { + methodBody.WriteIndentedString(2, "var stmt = conn.prepareStatement("+q.MethodName+");\n") + } // write the method signature body.WriteString("\n") @@ -259,7 +264,7 @@ func BuildQueriesFile(config core.Config, queryFilename string, queries []core.Q body.WriteString(",\n") } - methodBody.WriteIndentedString(2, arg.BindStmt()+"\n") + methodBody.WriteIndentedString(2, arg.BindStmt(engine)+"\n") } body.WriteString("\n") body.WriteIndentedString(1, ") throws SQLException {\n") diff --git a/internal/core/models.go b/internal/core/models.go index 1aacbdc..c07bdc9 100644 --- a/internal/core/models.go +++ b/internal/core/models.go @@ -41,6 +41,7 @@ type JavaType struct { Type string IsList bool IsNullable bool + IsEnum bool } type QueryArg struct { @@ -68,7 +69,7 @@ var typeToJavaSqlTypeConst = map[string]string{ "Double": "DOUBLE", } -func (q QueryArg) BindStmt() string { +func (q QueryArg) BindStmt(engine string) string { typeOnly := q.JavaType.Type[strings.LastIndex(q.JavaType.Type, ".")+1:] if q.JavaType.IsList { @@ -94,6 +95,21 @@ func (q QueryArg) BindStmt() string { return fmt.Sprintf("%s == null ? stmt.setNull(%d, java.sql.Types.%s) : %s", q.Name, q.Number, javaSqlType, rawSet) } + if q.JavaType.IsEnum { + // postgres doesn't like it if you setString an enum directly unfortunately + if engine == "postgresql" { + if q.JavaType.IsNullable { + return fmt.Sprintf("stmt.setObject(%d, %s == null ? null : %s.getValue(), java.sql.Types.OTHER);", q.Number, q.Name, q.Name) + } + return fmt.Sprintf("stmt.setObject(%d, %s.getValue(), java.sql.Types.OTHER);", q.Number, q.Name) + } + + if q.JavaType.IsNullable { + return fmt.Sprintf("stmt.setString(%d, %s == null ? null : %s.getValue());", q.Number, q.Name, q.Name) + } + return fmt.Sprintf("stmt.setString(%d, %s.getValue());", q.Number, q.Name) + } + return fmt.Sprintf("stmt.setObject(%d, %s);", q.Number, q.Name) } @@ -107,7 +123,6 @@ func (q QueryReturn) ResultStmt(number int) string { typeOnly := q.JavaType.Type[strings.LastIndex(q.JavaType.Type, ".")+1:] if q.JavaType.IsList { - // TODO - check for nullable array support if q.JavaType.IsNullable { return fmt.Sprintf("getList(results, %d, %s[].class)", number, typeOnly) } @@ -127,6 +142,13 @@ func (q QueryReturn) ResultStmt(number int) string { return fmt.Sprintf("results.get%s(%d)", typeOnly, number) } + if q.JavaType.IsEnum { + if q.JavaType.IsNullable { + return fmt.Sprintf("Optional.ofNullable(results.getString(%d)).map(%s::fromValue).orElse(null)", number, typeOnly) + } + return fmt.Sprintf("%s.fromValue(results.getString(%d))", typeOnly, number) + } + return fmt.Sprintf("results.getObject(%d, %s.class)", number, typeOnly) } @@ -149,7 +171,15 @@ type NullableHelpers struct { List bool } +type Enum struct { + Schema string + Name string + Values []string +} + type ( Queries map[string][]Query EmbeddedModels map[string][]QueryReturn + // Enums is a map of "schema_name.enum_name" to enum value. + Enums map[string]Enum ) diff --git a/internal/gen.go b/internal/gen.go index bb24b3e..57dc64b 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -25,8 +25,57 @@ var ( postgresPlaceholderRegexp = regexp.MustCompile(`\B\$\d+\b`) ) -func fixQueryPlaceholders(engine, query string) (string, error) { - if engine != "postgresql" { +type JavaGenerator struct { + req *plugin.GenerateRequest + conf core.Config + + queries core.Queries + models core.EmbeddedModels + + enums core.Enums + usedEnums []string + + typeConversionFunc sqltypes.TypeConversionFunc + nullableHelpers core.NullableHelpers +} + +func NewJavaGenerator(req *plugin.GenerateRequest) (*JavaGenerator, error) { + conf := core.Config{ + IndentChar: defaultIndentChar, + CharsPerIndentLevel: defaultCharsPerIndentLevel, + NullableAnnotation: "org.jspecify.annotations.Nullable", + NonNullAnnotation: "org.jspecify.annotations.NonNull", + } + if len(req.PluginOptions) > 0 { + if err := json.Unmarshal(req.PluginOptions, &conf); err != nil { + return nil, err + } + } + + var typeConversionFunc sqltypes.TypeConversionFunc + switch req.Settings.Engine { + case "postgresql": + typeConversionFunc = sqltypes.PostgresTypeToJavaType + case "mysql": + typeConversionFunc = sqltypes.MysqlTypeToJavaType + default: + return nil, fmt.Errorf("engine %q is not supported", req.Settings.Engine) + } + + return &JavaGenerator{ + req: req, + conf: conf, + queries: make(core.Queries), + models: make(core.EmbeddedModels), + enums: make(core.Enums), + usedEnums: make([]string, 0), + typeConversionFunc: typeConversionFunc, + nullableHelpers: core.NullableHelpers{}, + }, nil +} + +func (gen *JavaGenerator) fixQueryPlaceholders(query string) (string, error) { + if gen.req.Settings.Engine != "postgresql" { return query, nil } @@ -45,10 +94,23 @@ func fixQueryPlaceholders(engine, query string) (string, error) { return newQuery, nil } -func parseQueryReturn(tcf sqltypes.TypeConversionFunc, nullableHelpers *core.NullableHelpers, col *plugin.Column) (*core.QueryReturn, error) { - strJavaType, err := tcf(col.Type) +func (gen *JavaGenerator) parseQueryReturn(col *plugin.Column) (*core.QueryReturn, error) { + isEnum := false + strJavaType, err := gen.typeConversionFunc(col.Type) if err != nil { - return nil, err + schema := col.Table.Schema + if schema == "" { + schema = gen.req.Catalog.DefaultSchema + } + + enumQualifiedName := fmt.Sprintf("%s.%s", schema, col.Type.Name) + if _, ok := gen.enums[enumQualifiedName]; !ok { + return nil, err + } + + gen.usedEnums = append(gen.usedEnums, enumQualifiedName) + strJavaType = gen.conf.Package + ".enums." + codegen.EnumClassName(enumQualifiedName, gen.req.Catalog.DefaultSchema) + isEnum = true } if col.ArrayDims > 1 { @@ -60,23 +122,24 @@ func parseQueryReturn(tcf sqltypes.TypeConversionFunc, nullableHelpers *core.Nul Type: strJavaType, IsList: col.IsArray, IsNullable: !col.NotNull, + IsEnum: isEnum, } if javaType.IsNullable { if javaType.IsList { - nullableHelpers.List = true + gen.nullableHelpers.List = true } else { switch strJavaType { case "Integer": - nullableHelpers.Int = true + gen.nullableHelpers.Int = true case "Long": - nullableHelpers.Long = true + gen.nullableHelpers.Long = true case "Float": - nullableHelpers.Float = true + gen.nullableHelpers.Float = true case "Double": - nullableHelpers.Double = true + gen.nullableHelpers.Double = true case "Boolean": - nullableHelpers.Boolean = true + gen.nullableHelpers.Boolean = true } } } @@ -87,39 +150,22 @@ func parseQueryReturn(tcf sqltypes.TypeConversionFunc, nullableHelpers *core.Nul }, nil } -func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.GenerateResponse, error) { - conf := core.Config{ - IndentChar: defaultIndentChar, - CharsPerIndentLevel: defaultCharsPerIndentLevel, - NullableAnnotation: "org.jspecify.annotations.Nullable", - NonNullAnnotation: "org.jspecify.annotations.NonNull", - } - if len(req.PluginOptions) > 0 { - if err := json.Unmarshal(req.PluginOptions, &conf); err != nil { - return nil, err +func (gen *JavaGenerator) Run() (*plugin.GenerateResponse, error) { + // parse out the enums from the generate request + for _, schema := range gen.req.Catalog.Schemas { + for _, enum := range schema.Enums { + gen.enums[fmt.Sprintf("%s.%s", schema.Name, enum.Name)] = core.Enum{ + Schema: schema.Name, + Name: enum.Name, + Values: enum.Vals, + } } } - if conf.Package == "" { - return nil, fmt.Errorf("'package' is a required configuration option") - } - - var typeConversionFunc sqltypes.TypeConversionFunc - switch req.Settings.Engine { - case "postgresql": - typeConversionFunc = sqltypes.PostgresTypeToJavaType - default: - return nil, fmt.Errorf("engine %q is not supported", req.Settings.Engine) - } - - var queries core.Queries = make(map[string][]core.Query) - var embeddedModels core.EmbeddedModels = make(map[string][]core.QueryReturn) - nullableHelpers := core.NullableHelpers{} - // parse the incoming generate request into our Queries type - for _, query := range req.Queries { - if _, ok := queries[query.Filename]; !ok { - queries[query.Filename] = make([]core.Query, 0) + for _, query := range gen.req.Queries { + if _, ok := gen.queries[query.Filename]; !ok { + gen.queries[query.Filename] = make([]core.Query, 0) } command, err := core.QueryCommandFor(query.Cmd) @@ -130,9 +176,23 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat // TODO - enum types? other specialness? args := make([]core.QueryArg, 0) for _, arg := range query.Params { - javaType, err := typeConversionFunc(arg.Column.Type) + isEnum := false + javaType, err := gen.typeConversionFunc(arg.Column.Type) if err != nil { - return nil, err + // check if this is an enum type + schema := arg.Column.Table.Schema + if schema == "" { + schema = gen.req.Catalog.DefaultSchema + } + + enumQualifiedName := fmt.Sprintf("%s.%s", schema, arg.Column.Type.Name) + if _, ok := gen.enums[enumQualifiedName]; !ok { + return nil, err + } + + gen.usedEnums = append(gen.usedEnums, enumQualifiedName) + javaType = gen.conf.Package + ".enums." + codegen.EnumClassName(enumQualifiedName, gen.req.Catalog.DefaultSchema) + isEnum = true } if arg.Column.ArrayDims > 1 { @@ -145,8 +205,9 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat JavaType: core.JavaType{ SqlType: sdk.DataType(arg.Column.Type), Type: javaType, - IsList: arg.Column.IsArray, // TODO check this will always be present + IsList: arg.Column.IsArray, IsNullable: !arg.Column.NotNull, + IsEnum: isEnum, }, }) } @@ -156,7 +217,7 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat for _, ret := range query.Columns { if ret.EmbedTable == nil { // normal types - qr, err := parseQueryReturn(typeConversionFunc, &nullableHelpers, ret) + qr, err := gen.parseQueryReturn(ret) if err != nil { return nil, errors.Join(errors.New("failed to parse query return column"), err) } @@ -169,12 +230,12 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat var table *plugin.Table // find the catalog entry for the embedded table - schema := req.Catalog.DefaultSchema + schema := gen.req.Catalog.DefaultSchema if ret.EmbedTable.Schema != "" { schema = ret.EmbedTable.Schema } - for _, s := range req.Catalog.Schemas { + for _, s := range gen.req.Catalog.Schemas { if s.Name != schema { continue } @@ -192,15 +253,15 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat // TODO - fix type-writer to only exclude items that aren't part of the package name modelName := strcase.ToCamel(table.Rel.Name) - if !conf.EmitExactTableNames { - modelName = strcase.ToCamel(inflection.Singular(table.Rel.Name, conf.InflectionExcludeTableNames)) + if !gen.conf.EmitExactTableNames { + modelName = strcase.ToCamel(inflection.Singular(table.Rel.Name, gen.conf.InflectionExcludeTableNames)) } // check if we already have an entry for this model - if _, ok := embeddedModels[modelName]; !ok { + if _, ok := gen.models[modelName]; !ok { var modelParams []core.QueryReturn for _, c := range table.Columns { - qr, err := parseQueryReturn(typeConversionFunc, &nullableHelpers, c) + qr, err := gen.parseQueryReturn(c) if err != nil { return nil, errors.Join(errors.New("failed to parse query return column"), err) } @@ -208,7 +269,7 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat modelParams = append(modelParams, *qr) } - embeddedModels[modelName] = modelParams + gen.models[modelName] = modelParams } returns = append(returns, core.QueryReturn{ @@ -216,7 +277,7 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat JavaType: core.JavaType{ SqlType: "", // we don't need to specify package here - models file will be generated in the same location as the queries file - Type: conf.Package + ".models." + modelName, + Type: gen.conf.Package + ".models." + modelName, IsList: false, // TODO - check: this *should* be impossible IsNullable: false, // TODO - check: empty record should be output instead }, @@ -225,12 +286,12 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat } // TODO - look into fixing ? operator for postgresql JSONB operations maybe - newQueryText, err := fixQueryPlaceholders(req.Settings.Engine, query.Text) + newQueryText, err := gen.fixQueryPlaceholders(query.Text) if err != nil { return nil, err } - queries[query.Filename] = append(queries[query.Filename], core.Query{ + gen.queries[query.Filename] = append(gen.queries[query.Filename], core.Query{ RawCommand: query.Cmd, Command: command, Text: newQueryText, @@ -243,12 +304,12 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat } outputFiles := make([]*plugin.File, 0) - // order the queries for each file alphabetically - for file := range queries { - slices.SortFunc(queries[file], func(a, b core.Query) int { return strings.Compare(a.MethodName, b.MethodName) }) + for file := range gen.queries { + // order the queries for each file alphabetically + slices.SortFunc(gen.queries[file], func(a, b core.Query) int { return strings.Compare(a.MethodName, b.MethodName) }) // build the queries file contents - fileName, fileContents, err := codegen.BuildQueriesFile(conf, file, queries[file], embeddedModels, nullableHelpers) + fileName, fileContents, err := codegen.BuildQueriesFile(gen.req.Settings.Engine, gen.conf, file, gen.queries[file], gen.models, gen.nullableHelpers) if err != nil { return nil, err } @@ -258,8 +319,27 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat }) } - for modelName, model := range embeddedModels { - fileName, fileContents, err := codegen.BuildModelFile(conf, modelName, model) + for modelName, model := range gen.models { + fileName, fileContents, err := codegen.BuildModelFile(gen.conf, modelName, model) + if err != nil { + return nil, err + } + outputFiles = append(outputFiles, &plugin.File{ + Name: fileName, + Contents: fileContents, + }) + } + + // remove duplicate enum entries + slices.Sort(gen.usedEnums) + slices.Compact(gen.usedEnums) + for _, qualName := range gen.usedEnums { + if qualName == "" { + continue + } + + enum := gen.enums[qualName] + fileName, fileContents, err := codegen.BuildEnumFile(gen.req.Settings.Engine, gen.conf, qualName, enum, gen.req.Catalog.DefaultSchema) if err != nil { return nil, err } @@ -271,3 +351,13 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat return &plugin.GenerateResponse{Files: outputFiles}, nil } + +// TODO - check if the context is actually important for anything +func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.GenerateResponse, error) { + jg, err := NewJavaGenerator(req) + if err != nil { + return nil, err + } + + return jg.Run() +} diff --git a/internal/sqltypes/mysql.go b/internal/sqltypes/mysql.go new file mode 100644 index 0000000..5b9b9a8 --- /dev/null +++ b/internal/sqltypes/mysql.go @@ -0,0 +1,40 @@ +package sqltypes + +import ( + "fmt" + "github.com/sqlc-dev/plugin-sdk-go/plugin" + "github.com/sqlc-dev/plugin-sdk-go/sdk" +) + +func MysqlTypeToJavaType(identifier *plugin.Identifier) (string, error) { + colType := sdk.DataType(identifier) + + switch colType { + case "varchar", "text", "char", "tinytext", "mediumtext", "longtext": + return "String", nil + case "int", "integer", "smallint", "mediumint", "year": + return "Integer", nil + case "bigint": + return "Long", nil + case "blob", "binary", "varbinary", "tinyblob", "mediumblob", "longblob": + return "byte[]", nil + case "double", "double precision", "real": + return "Double", nil + // TODO - why does gen-kotlin use string for this? look into a better solution + case "decimal", "dec", "fixed": + return "String", nil + case "date": + return "java.time.LocalDate", nil + case "datetime", "time": + return "java.time.LocalDateTime", nil + // TODO - instant support - look into option for this in pgsql as well + case "timestamp": + return "java.time.OffsetDateTime", nil + case "boolean", "bool", "tinyint": + return "Boolean", nil + case "json": + return "String", nil + default: + return "", fmt.Errorf("datatype '%s' not currently supported", colType) + } +} diff --git a/internal/sqltypes/postgresql.go b/internal/sqltypes/postgresql.go index c44d32f..042a5c8 100644 --- a/internal/sqltypes/postgresql.go +++ b/internal/sqltypes/postgresql.go @@ -24,9 +24,6 @@ func PostgresTypeToJavaType(identifier *plugin.Identifier) (string, error) { return "java.math.BigDecimal", nil case "bool", "pg_catalog.bool": return "Boolean", nil - // TODO - figure out if this can be supported properly - case "jsonb": - return "String", nil case "bytea", "blob", "pg_catalog.bytea": return "byte[]", nil case "date": @@ -41,10 +38,11 @@ func PostgresTypeToJavaType(identifier *plugin.Identifier) (string, error) { return "String", nil case "uuid": return "java.util.UUID", nil - case "inet", "void", "any": - return "", fmt.Errorf("datatype '%s' not currently supported", colType) + // TODO - figure out if these can be supported properly + case "jsonb", "inet": + return "String", nil default: - // TODO - deal with enums somehow + // void, any return "", fmt.Errorf("datatype '%s' not currently supported", colType) } } diff --git a/tests/pom.xml b/tests/pom.xml index 2c74b14..4b1aff8 100644 --- a/tests/pom.xml +++ b/tests/pom.xml @@ -27,16 +27,23 @@ + org.jspecify jspecify 1.0.0 + org.postgresql postgresql 42.7.5 + + com.mysql + mysql-connector-j + 9.2.0 + org.junit.jupiter @@ -49,6 +56,7 @@ assertj-core 3.27.3 + org.testcontainers testcontainers @@ -64,6 +72,11 @@ postgresql test + + org.testcontainers + mysql + test + diff --git a/tests/sqlc.yaml b/tests/sqlc.yaml index 1431440..e045e6c 100644 --- a/tests/sqlc.yaml +++ b/tests/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: java wasm: url: file://sqlc-gen-java.wasm - sha256: a2211754142fd64604a7b5b2cbec985f685830ca07db860974a7996f6c06fb5b + sha256: 1e30725cfd253e18486ebff9ba7eea481bc2680d946ecae85d308de39485b921 sql: - schema: src/main/resources/postgres/schema.sql queries: src/main/resources/postgres/queries.sql @@ -13,3 +13,11 @@ sql: plugin: java options: package: io.github.tandemdude.sgj.postgres + - schema: src/main/resources/mysql/schema.sql + queries: src/main/resources/mysql/queries.sql + engine: mysql + codegen: + - out: src/main/java/io/github/tandemdude/sgj/mysql + plugin: java + options: + package: io.github.tandemdude.sgj.mysql diff --git a/tests/src/main/resources/mysql/queries.sql b/tests/src/main/resources/mysql/queries.sql new file mode 100644 index 0000000..b235fa1 --- /dev/null +++ b/tests/src/main/resources/mysql/queries.sql @@ -0,0 +1,60 @@ +-- name: GetAuthor :one +SELECT * FROM authors +WHERE author_id = ?; + +-- name: GetBook :one +SELECT * FROM books +WHERE book_id = ?; + +-- name: DeleteBook :exec +DELETE FROM books +WHERE book_id = ?; + +-- name: BooksByTitleYear :many +SELECT * FROM books +WHERE title = ? AND yr = ?; + +-- name: BooksByTags :many +SELECT + book_id, + title, + name, + isbn, + tags +FROM books + LEFT JOIN authors ON books.author_id = authors.author_id +WHERE tags = ?; + +-- name: CreateAuthor :execresult +INSERT INTO authors (name) VALUES (?); + +-- name: CreateBook :execresult +INSERT INTO books ( + author_id, + isbn, + book_type, + title, + yr, + available, + tags +) VALUES (?, ?, ?, ?, ?, ?, ?); + +-- name: UpdateBook :exec +UPDATE books +SET title = ?, tags = ? +WHERE book_id = ?; + +-- name: UpdateBookISBN :exec +UPDATE books +SET title = ?, tags = ?, isbn = ? +WHERE book_id = ?; + +-- name: DeleteAuthorBeforeYear :exec +DELETE FROM books +WHERE yr < ? AND author_id = ?; + +-- name: CreateEnumRow :execresult +INSERT INTO nullable_enum_test (enum_field) VALUES (?); + +-- name: GetEnumRow :one +SELECT * FROM nullable_enum_test WHERE t_id = ?; \ No newline at end of file diff --git a/tests/src/main/resources/mysql/schema.sql b/tests/src/main/resources/mysql/schema.sql new file mode 100644 index 0000000..ed6b2aa --- /dev/null +++ b/tests/src/main/resources/mysql/schema.sql @@ -0,0 +1,26 @@ +-- taken from sqlc-gen-kotlin, booktest -- +CREATE TABLE authors ( + author_id integer NOT NULL AUTO_INCREMENT PRIMARY KEY, + name text NOT NULL +) ENGINE=InnoDB; + +CREATE INDEX authors_name_idx ON authors(name(255)); + +CREATE TABLE books ( + book_id integer NOT NULL AUTO_INCREMENT PRIMARY KEY, + author_id integer NOT NULL, + isbn varchar(255) NOT NULL DEFAULT '' UNIQUE, + book_type ENUM('FICTION', 'NONFICTION') NOT NULL DEFAULT 'FICTION', + title text NOT NULL, + yr integer NOT NULL DEFAULT 2000, + available datetime NOT NULL DEFAULT NOW(), + tags text NOT NULL +) ENGINE=InnoDB; + +CREATE INDEX books_title_idx ON books(title(255), yr); +-- end -- + +CREATE TABLE nullable_enum_test ( + t_id integer NOT NULL AUTO_INCREMENT PRIMARY KEY, + enum_field ENUM('foo', 'bar') DEFAULT NULL +); diff --git a/tests/src/main/resources/postgres/queries.sql b/tests/src/main/resources/postgres/queries.sql index 6fa0a75..36dc980 100644 --- a/tests/src/main/resources/postgres/queries.sql +++ b/tests/src/main/resources/postgres/queries.sql @@ -49,3 +49,9 @@ RETURNING row_id; -- name: GetBytes :one SELECT * FROM bytes WHERE row_id = $1; + +-- name: CreatePerson :exec +INSERT INTO person(name, current_mood, next_mood) VALUES ($1, $2, $3); + +-- name: GetPerson :one +SELECT * FROM person WHERE name = $1; diff --git a/tests/src/main/resources/postgres/schema.sql b/tests/src/main/resources/postgres/schema.sql index c7c0d4d..663f1d4 100644 --- a/tests/src/main/resources/postgres/schema.sql +++ b/tests/src/main/resources/postgres/schema.sql @@ -38,3 +38,10 @@ CREATE TABLE bytes ( contents BYTEA NOT NULL, hash BYTEA DEFAULT NULL ); + +CREATE TYPE mood AS ENUM ('sad', 'happy', 'ok'); +CREATE TABLE person ( + name TEXT PRIMARY KEY, + current_mood mood NOT NULL, + next_mood mood DEFAULT NULL +); diff --git a/tests/src/test/java/io/github/tandemdude/sgj/mysql/TestQueries.java b/tests/src/test/java/io/github/tandemdude/sgj/mysql/TestQueries.java new file mode 100644 index 0000000..c629bb7 --- /dev/null +++ b/tests/src/test/java/io/github/tandemdude/sgj/mysql/TestQueries.java @@ -0,0 +1,63 @@ +package io.github.tandemdude.sgj.mysql; + +import io.github.tandemdude.sgj.mysql.enums.BooksBookType; +import io.github.tandemdude.sgj.mysql.enums.NullableEnumTestEnumField; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.MySQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.time.LocalDateTime; + +import static org.assertj.core.api.Assertions.assertThat; + +@Testcontainers +public class TestQueries { + @Container + private final MySQLContainer mysql = new MySQLContainer<>("mysql:latest") + .withInitScript("mysql/schema.sql"); + + Connection getConn() throws SQLException { + var conn = DriverManager.getConnection(mysql.getJdbcUrl(), mysql.getUsername(), mysql.getPassword()); + conn.setAutoCommit(true); + return conn; + } + + @Test + @DisplayName("enum types can be read and written") + void enumTypesCanBeReadAndWritten() throws Exception { + try (var conn = getConn()) { + var q = new Queries(conn); + + var authorId = q.createAuthor("foo"); + var bookId = q.createBook((int) authorId, "foo", BooksBookType.FICTION, "bar", 2000, LocalDateTime.now(), "baz"); + + var foundRow = q.getBook((int) bookId); + assertThat(foundRow).isPresent(); + var found = foundRow.get(); + assertThat(found.bookType()).isEqualTo(BooksBookType.FICTION); + } + } + + @Test + @DisplayName("nullable enum types can be read and written") + void nullableEnumTypesCanBeReadAndWritten() throws Exception { + try (var conn = getConn()) { + var q = new Queries(conn); + + var id = q.createEnumRow(NullableEnumTestEnumField.BAR); + var foundRow = q.getEnumRow((int) id); + assertThat(foundRow).isPresent(); + assertThat(foundRow.get().enumField()).isEqualTo(NullableEnumTestEnumField.BAR); + + var id2 = q.createEnumRow(null); + var foundRow2 = q.getEnumRow((int) id2); + assertThat(foundRow2).isPresent(); + assertThat(foundRow2.get().enumField()).isNull(); + } + } +} diff --git a/tests/src/test/java/io/github/tandemdude/sgj/postgres/TestQueries.java b/tests/src/test/java/io/github/tandemdude/sgj/postgres/TestQueries.java index 88d7401..07d4a1c 100644 --- a/tests/src/test/java/io/github/tandemdude/sgj/postgres/TestQueries.java +++ b/tests/src/test/java/io/github/tandemdude/sgj/postgres/TestQueries.java @@ -1,6 +1,6 @@ package io.github.tandemdude.sgj.postgres; -import org.jspecify.annotations.NonNull; +import io.github.tandemdude.sgj.postgres.enums.Mood; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.DisplayName; import org.testcontainers.containers.PostgreSQLContainer; @@ -32,7 +32,7 @@ Connection getConn() throws SQLException { @Test @DisplayName("GetUser returns empty optional when no records found") - public void getUserReturnsEmptyOptionalNoRecordsFound() throws Exception { + void getUserReturnsEmptyOptionalNoRecordsFound() throws Exception { try (var conn = getConn()) { var q = new Queries(conn); @@ -42,7 +42,7 @@ public void getUserReturnsEmptyOptionalNoRecordsFound() throws Exception { @Test @DisplayName("GetUser returns populated optional when record found") - public void getUserReturnsPopulatedOptionalRecordFound() throws Exception { + void getUserReturnsPopulatedOptionalRecordFound() throws Exception { try (var conn = getConn()) { var q = new Queries(conn); @@ -59,7 +59,7 @@ public void getUserReturnsPopulatedOptionalRecordFound() throws Exception { @Test @DisplayName("ListUsers returns empty list when no records found") - public void listUsersReturnsEmptyListNoRecordsFound() throws Exception { + void listUsersReturnsEmptyListNoRecordsFound() throws Exception { try (var conn = getConn()) { var q = new Queries(conn); @@ -69,7 +69,7 @@ public void listUsersReturnsEmptyListNoRecordsFound() throws Exception { @Test @DisplayName("ListUsers returns populated list when records found") - public void listUsersReturnsPopulatedListRecordsFound() throws Exception { + void listUsersReturnsPopulatedListRecordsFound() throws Exception { try (var conn = getConn()) { var q = new Queries(conn); @@ -84,7 +84,7 @@ public void listUsersReturnsPopulatedListRecordsFound() throws Exception { @Test @DisplayName("GetUserDup throws error when multiple records returned") - public void getUserDupReturnsErrorWhenMultipleRecordsReturned() throws Exception { + void getUserDupReturnsErrorWhenMultipleRecordsReturned() throws Exception { try (var conn = getConn()) { var q = new Queries(conn); @@ -99,7 +99,7 @@ public void getUserDupReturnsErrorWhenMultipleRecordsReturned() throws Exception @Test @DisplayName("CreateMessage processes input list correctly") - public void createMessageProcessesInputListCorrectly() throws Exception { + void createMessageProcessesInputListCorrectly() throws Exception { try (var conn = getConn()) { var q = new Queries(conn); @@ -114,7 +114,7 @@ public void createMessageProcessesInputListCorrectly() throws Exception { @Test @DisplayName("GetMessage works when attachments is null") - public void getMessageWorksWhenAttachmentsIsNull() throws Exception { + void getMessageWorksWhenAttachmentsIsNull() throws Exception { try (var conn = getConn()) { var q = new Queries(conn); @@ -129,7 +129,7 @@ public void getMessageWorksWhenAttachmentsIsNull() throws Exception { @Test @DisplayName("GetUserAndToken returns embedded objects") - public void getUserAndTokenReturnsEmbeddedObjects() throws Exception { + void getUserAndTokenReturnsEmbeddedObjects() throws Exception { try (var conn = getConn()) { var q = new Queries(conn); @@ -146,7 +146,7 @@ public void getUserAndTokenReturnsEmbeddedObjects() throws Exception { @Test @DisplayName("GetBytes returns same data as during creation") - public void getBytesReturnsSameDataAsDuringCreation() throws Exception { + void getBytesReturnsSameDataAsDuringCreation() throws Exception { try (var conn = getConn()) { var q = new Queries(conn); @@ -169,4 +169,24 @@ public void getBytesReturnsSameDataAsDuringCreation() throws Exception { assertThat(found2.get().hash()).isEqualTo(s2); } } + + @Test + @DisplayName("nullable and non-nullable enums work correctly") + void nullableAndNonNullableEnumsWorkCorrectly() throws Exception { + try (var conn = getConn()) { + var q = new Queries(conn); + + q.createPerson("foo", Mood.HAPPY, null); + var p1 = q.getPerson("foo"); + assertThat(p1).isPresent(); + assertThat(p1.get().currentMood()).isEqualTo(Mood.HAPPY); + assertThat(p1.get().nextMood()).isNull(); + + q.createPerson("bar", Mood.HAPPY, Mood.OK); + var p2 = q.getPerson("bar"); + assertThat(p2).isPresent(); + assertThat(p2.get().currentMood()).isEqualTo(Mood.HAPPY); + assertThat(p2.get().nextMood()).isEqualTo(Mood.OK); + } + } }