diff --git a/.gitignore b/.gitignore index 97f1562..e37a153 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ *.wasm .DS_Store -.idea/ \ No newline at end of file +.idea/ +main.go diff --git a/go.mod b/go.mod index 4c8aa4f..305bfcc 100644 --- a/go.mod +++ b/go.mod @@ -1,17 +1,33 @@ module github.com/tandemdude/sqlc-gen-java -go 1.23 +go 1.24.5 -require github.com/sqlc-dev/plugin-sdk-go v1.23.0 +require ( + github.com/sqlc-dev/plugin-sdk-go v1.23.0 + github.com/stretchr/testify v1.10.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/google/go-cmp v0.6.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/mod v0.24.0 // indirect + golang.org/x/sync v0.13.0 // indirect + golang.org/x/tools v0.32.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect + mvdan.cc/gofumpt v0.8.0 // indirect +) require ( github.com/golang/protobuf v1.5.3 // indirect github.com/iancoleman/strcase v0.3.0 github.com/jinzhu/inflection v1.0.0 - golang.org/x/net v0.14.0 // indirect - golang.org/x/sys v0.11.0 // indirect - golang.org/x/text v0.12.0 // indirect + golang.org/x/net v0.39.0 // indirect + golang.org/x/sys v0.32.0 // indirect + golang.org/x/text v0.24.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d // indirect google.golang.org/grpc v1.59.0 // indirect google.golang.org/protobuf v1.31.0 // indirect ) + +tool mvdan.cc/gofumpt diff --git a/go.sum b/go.sum index a86813d..f8672d6 100644 --- a/go.sum +++ b/go.sum @@ -1,21 +1,41 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI= +github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/iancoleman/strcase v0.3.0 h1:nTXanmYxhfFAMjZL34Ov6gkzEsSJZ5DbhxWjvSASxEI= github.com/iancoleman/strcase v0.3.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/sqlc-dev/plugin-sdk-go v1.23.0 h1:iSeJhnXPlbDXlbzUEebw/DxsGzE9rdDJArl8Hvt0RMM= github.com/sqlc-dev/plugin-sdk-go v1.23.0/go.mod h1:I1r4THOfyETD+LI2gogN2LX8wCjwUZrgy/NU4In3llA= -golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14= -golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= -golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= -golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc= -golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= +golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= +golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= +golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= +golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= +golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= +golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= +golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= +golang.org/x/tools v0.32.0 h1:Q7N1vhpkQv7ybVzLFtTjvQya2ewbwNDZzUgfXGqtMWU= +golang.org/x/tools v0.32.0/go.mod h1:ZxrU41P/wAbZD8EDa6dDCa6XfpkhJ7HFMjHJXfBDu8s= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d h1:uvYuEyMHKNt+lT4K3bN6fGswmK8qSvcreM3BwjDh+y4= google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d/go.mod h1:+Bk1OCOj40wS2hwAMA+aCW9ypzm63QTBBHp6lQ3p+9M= @@ -25,3 +45,9 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +mvdan.cc/gofumpt v0.8.0 h1:nZUCeC2ViFaerTcYKstMmfysj6uhQrA2vJe+2vwGU6k= +mvdan.cc/gofumpt v0.8.0/go.mod h1:vEYnSzyGPmjvFkqJWtXkh79UwPWP9/HMxQdGEXZHjpg= diff --git a/internal/codegen/enums.go b/internal/codegen/enums.go index 9846212..877cf04 100644 --- a/internal/codegen/enums.go +++ b/internal/codegen/enums.go @@ -2,12 +2,13 @@ package codegen import ( "fmt" - "github.com/iancoleman/strcase" - "github.com/tandemdude/sqlc-gen-java/internal/core" "regexp" "strings" "unicode" "unicode/utf8" + + "github.com/iancoleman/strcase" + "github.com/tandemdude/sqlc-gen-java/internal/core" ) var javaInvalidIdentChars = regexp.MustCompile("[^$\\w]") diff --git a/internal/core/constants.go b/internal/core/constants.go index f5f56aa..dea5de9 100644 --- a/internal/core/constants.go +++ b/internal/core/constants.go @@ -1,3 +1,15 @@ package core +import ( + "os" + "strings" +) + const PluginVersion = "0.0.6" + +var FileHeaderComment = strings.Join([]string{ + "// Code generated by sqlc. DO NOT EDIT.", + "// versions:", + "// sqlc " + os.Getenv("SQLC_VERSION"), + "// sqlc-gen-java " + PluginVersion, +}, "\n") diff --git a/internal/inflection/singular.go b/internal/inflection/singular.go index 3af8d45..904cc78 100644 --- a/internal/inflection/singular.go +++ b/internal/inflection/singular.go @@ -1,9 +1,10 @@ package inflection import ( - "github.com/jinzhu/inflection" "slices" "strings" + + "github.com/jinzhu/inflection" ) func Singular(s string, excludes []string) string { diff --git a/internal/sqltypes/mysql.go b/internal/sqltypes/mysql.go index 0586b38..2c030d6 100644 --- a/internal/sqltypes/mysql.go +++ b/internal/sqltypes/mysql.go @@ -3,6 +3,8 @@ package sqltypes import ( "fmt" + "github.com/tandemdude/sqlc-gen-java/poet" + "github.com/sqlc-dev/plugin-sdk-go/plugin" "github.com/sqlc-dev/plugin-sdk-go/sdk" ) @@ -38,3 +40,35 @@ func MysqlTypeToJavaType(identifier *plugin.Identifier) (string, error) { return "", fmt.Errorf("datatype '%s' not currently supported", colType) } } + +func ConvertMySQLType(identifier *plugin.Identifier) (poet.TypeName, error) { + colType := sdk.DataType(identifier) + + switch colType { + case "varchar", "text", "char", "tinytext", "mediumtext", "longtext": + return poet.String, nil + case "int", "integer", "smallint", "mediumint", "year": + return poet.IntBoxed, nil + case "bigint": + return poet.LongBoxed, nil + case "blob", "binary", "varbinary", "tinyblob", "mediumblob", "longblob": + return poet.Byte.Array(), nil + case "double", "double precision", "real": + return poet.DoubleBoxed, nil + case "decimal", "dec", "fixed": + return poet.NewClassName("java.math", "BigDecimal"), nil + case "date": + return poet.NewClassName("java.time", "LocalTime"), nil + case "datetime", "time": + return poet.NewClassName("java.time", "LocalDateTime"), nil + // TODO - instant support - look into option for this in pgsql as well + case "timestamp": + return poet.NewClassName("java.time", "OffsetDateTime"), nil + case "boolean", "bool", "tinyint": + return poet.BoolBoxed, nil + case "json": + return poet.String, nil + default: + return poet.TypeName{}, fmt.Errorf("datatype '%s' not currently supported", colType) + } +} diff --git a/internal/sqltypes/postgresql.go b/internal/sqltypes/postgresql.go index 042a5c8..92b95dd 100644 --- a/internal/sqltypes/postgresql.go +++ b/internal/sqltypes/postgresql.go @@ -2,6 +2,9 @@ package sqltypes import ( "fmt" + + "github.com/tandemdude/sqlc-gen-java/poet" + "github.com/sqlc-dev/plugin-sdk-go/plugin" "github.com/sqlc-dev/plugin-sdk-go/sdk" ) @@ -46,3 +49,44 @@ func PostgresTypeToJavaType(identifier *plugin.Identifier) (string, error) { return "", fmt.Errorf("datatype '%s' not currently supported", colType) } } + +func ConvertPostgresType(identifier *plugin.Identifier) (poet.TypeName, error) { + colType := sdk.DataType(identifier) + + switch colType { + case "serial", "pg_catalog.serial4", "integer", "int", "int4", "pg_catalog.int4": + return poet.IntBoxed, nil + case "bigserial", "pg_catalog.serial8", "bigint", "pg_catalog.int8": + return poet.LongBoxed, nil + case "smallserial", "pg_catalog.serial2", "smallint", "pg_catalog.int2": + return poet.ShortBoxed, nil + case "float", "double precision", "pg_catalog.float8": + return poet.DoubleBoxed, nil + case "real", "pg_catalog.float4": + return poet.FloatBoxed, nil + case "pg_catalog.numeric": + return poet.NewClassName("java.math", "BigDecimal"), nil + case "bool", "pg_catalog.bool": + return poet.BoolBoxed, nil + case "bytea", "blob", "pg_catalog.bytea": + return poet.Byte.Array(), nil + case "date": + return poet.NewClassName("java.time", "LocalDate"), nil + case "pg_catalog.time", "pg_catalog.timetz": + return poet.NewClassName("java.time", "LocalTime"), nil + case "pg_catalog.timestamp", "timestamp": + return poet.NewClassName("java.time", "LocalDateTime"), nil + case "pg_catalog.timestamptz", "timestamptz": + return poet.NewClassName("java.time", "OffsetDateTime"), nil + case "text", "pg_catalog.varchar", "pg_catalog.bpchar", "string": + return poet.String, nil + case "uuid": + return poet.NewClassName("java.util", "UUID"), nil + // TODO - figure out if these can be supported properly + case "jsonb", "inet": + return poet.String, nil + default: + // void, any + return poet.TypeName{}, fmt.Errorf("datatype '%s' not currently supported", colType) + } +} diff --git a/poet/code.go b/poet/code.go new file mode 100644 index 0000000..b150e20 --- /dev/null +++ b/poet/code.go @@ -0,0 +1,175 @@ +package poet + +import ( + "fmt" + "regexp" + "strings" +) + +const ( + replaceTypeLiteral rune = 'L' + replaceTypeString rune = 'S' + replaceTypeType rune = 'T' +) + +var stringFormatRegex = regexp.MustCompile(`\${1,2}\d*[LST]`) + +type Code struct { + RawCode string + IsFlow bool + IsTryCatch bool + IsIfElse bool + + Arguments []any + Statements []Code +} + +func stringify(raw any) string { + if raw == nil { + return "null" + } + return fmt.Sprintf("%v", raw) +} + +func formatRawCode(ctx *Context, rawCode string, arguments []any) string { + matchIndex := 0 + + return stringFormatRegex.ReplaceAllStringFunc(rawCode, func(match string) string { + // if the pattern is escaped + if strings.HasPrefix(match, "$$") { + return match[1:] + } + + argumentIndex, replaceType := 0, rune(match[len(match)-1]) + for i := 1; i < len(match)-1; i++ { + argumentIndex = (argumentIndex * 10) + int(match[i]-'0') + } + + argumentIndex -= 1 + if argumentIndex < 0 { + argumentIndex = matchIndex + } + + if argumentIndex > len(arguments)-1 { + // tried to access an argument that is not there - TODO allow errors to be returned from formatting + return match + } + + replacement := match + switch replaceType { + case replaceTypeLiteral: + replacement = stringify(arguments[argumentIndex]) + case replaceTypeString: + replacement = fmt.Sprintf("%q", stringify(arguments[argumentIndex])) + case replaceTypeType: + // it is unlikely that a user will ever want to include the generic constraint in the formatted + // value - in the future could make this configurable through extended replace type codes + replacement = arguments[argumentIndex].(TypeName).Format(ctx, ExcludeConstraints) + } + + if len(match) == 2 { + matchIndex += 1 + } + + return replacement + }) +} + +func shouldInlineNextStmt(stmt Code, nextStmt Code) bool { + if !nextStmt.IsFlow { + return false + } + + if stmt.IsIfElse { + return nextStmt.IsIfElse && !strings.HasPrefix(nextStmt.RawCode, "if") + } else if stmt.IsTryCatch { + return nextStmt.IsTryCatch && !strings.HasPrefix(nextStmt.RawCode, "try") + } + + return false +} + +func formatStatements(ctx *Context, statements []Code) string { + var sb strings.Builder + + for i, stmt := range statements { + sb.WriteString(stmt.Format(ctx)) + + if stmt.IsFlow && (len(statements) > i+1) && shouldInlineNextStmt(stmt, statements[i+1]) { + sb.WriteString(" ") + } else { + sb.WriteString("\n") + } + } + + return sb.String() +} + +func (c *Code) Format(ctx *Context) string { + var sb strings.Builder + + if c.IsFlow { + // Control flow statement + sb.WriteString(formatRawCode(ctx, c.RawCode, c.Arguments)) + sb.WriteString(" {\n") + sb.WriteString(ctx.indent(formatStatements(ctx, c.Statements))) + sb.WriteString("}") + + return sb.String() + } + + if c.RawCode != "" && !c.IsFlow { + // Simple statement + sb.WriteString(formatRawCode(ctx, c.RawCode, c.Arguments)) + if !strings.HasSuffix(c.RawCode, ";") { + sb.WriteRune(';') + } + + return sb.String() + } + + // List of statements + sb.WriteString(formatStatements(ctx, c.Statements)) + + return sb.String() +} + +type CodeBuilder struct { + code Code +} + +func NewCodeBuilder() *CodeBuilder { + return &CodeBuilder{} +} + +func (b *CodeBuilder) WithStatement(stmt string, args ...any) *CodeBuilder { + b.code.Statements = append(b.code.Statements, Code{RawCode: stmt, Arguments: args}) + return b +} + +func (b *CodeBuilder) WithControlFlow(stmt string, blockBuilderFn func(*CodeBuilder), args ...any) *CodeBuilder { + builder := NewCodeBuilder() + builder.code.Arguments = args + builder.code.RawCode = stmt + builder.code.IsFlow = true + + if strings.HasPrefix(stmt, "if") || strings.HasPrefix(stmt, "else") { + builder.code.IsIfElse = true + } else if strings.HasPrefix(stmt, "try") || strings.HasPrefix(stmt, "catch") || strings.HasPrefix(stmt, "finally") { + builder.code.IsTryCatch = true + } + + blockBuilderFn(builder) + + b.code.Statements = append(b.code.Statements, builder.Build()) + return b +} + +func (b *CodeBuilder) WithRawCode(code string) *CodeBuilder { + b.code.Statements = append(b.code.Statements, Code{RawCode: code}) + return b +} + +func (b *CodeBuilder) Build() Code { + return b.code +} diff --git a/poet/code_test.go b/poet/code_test.go new file mode 100644 index 0000000..821b517 --- /dev/null +++ b/poet/code_test.go @@ -0,0 +1,104 @@ +package poet + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCodeBuilder_StringFormatting(t *testing.T) { + tests := []struct { + name string + code string + expected string + args []any + }{ + { + name: "simple replacement", + code: "$T $L = $S", + expected: `String test = "a value"`, + args: []any{String, "test", "a value"}, + }, + { + name: "complex replacement", + code: "$T $3L = $2S + $3S", + expected: `String test = "a value" + "test"`, + args: []any{String, "a value", "test"}, + }, + { + name: "nil values", + code: "$1L $1S", + expected: `null "null"`, + args: []any{nil}, + }, + { + name: "bool values", + code: "$1L $1S", + expected: `true "true"`, + args: []any{true}, + }, + { + name: "int values", + code: "$1L $1S", + expected: `1 "1"`, + args: []any{1}, + }, + { + name: "float values", + code: "$1L $1S", + expected: `1.1 "1.1"`, + args: []any{1.1}, + }, + { + name: "string values", + code: "$1L $1S", + expected: `hello "hello"`, + args: []any{"hello"}, + }, + { + name: "quoted strings", + code: "$1L $1S", + expected: `"hello" "\"hello\""`, + args: []any{`"hello"`}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + assert := assert.New(t) + + ctx := NewContext("io.github.tandemdude") + + code := NewCodeBuilder(). + WithStatement(tt.code, tt.args...). + Build() + + assert.Equal(tt.expected+";\n", code.Format(ctx)) + }) + } +} + +func TestCodeBuilder_MultipleStatements(t *testing.T) { + assert := assert.New(t) + + ctx := NewContext("io.github.tandemdude") + + code := NewCodeBuilder(). + WithStatement("$L $S", false, true). + WithStatement("$L $S", true, false). + Build() + + assert.Equal("false \"true\";\ntrue \"false\";\n", code.Format(ctx)) +} + +func TestCodeBuilder_NoStatements(t *testing.T) { + assert := assert.New(t) + + ctx := NewContext("io.github.tandemdude") + + code := NewCodeBuilder().Build() + + assert.Equal("", code.Format(ctx)) +} diff --git a/poet/context.go b/poet/context.go new file mode 100644 index 0000000..05f78c4 --- /dev/null +++ b/poet/context.go @@ -0,0 +1,74 @@ +package poet + +import ( + "context" + "slices" + "strings" +) + +type Context struct { + context.Context + + CurrentPackage string + CurrentTypeName string + + Imports []string + Types map[string]TypeName + + Indent string +} + +func WithIndent(indent string) func(*Context) { + return func(ctx *Context) { + ctx.Indent = indent + } +} + +func NewContextFromContext(ctx context.Context, currentPackage string, options ...func(*Context)) *Context { + newContext := &Context{ + Context: ctx, + CurrentPackage: currentPackage, + Types: make(map[string]TypeName), + } + + for _, option := range options { + option(newContext) + } + + return newContext +} + +func NewContext(currentPackage string, options ...func(*Context)) *Context { + return NewContextFromContext(context.Background(), currentPackage, options...) +} + +func (ctx *Context) Import(imports ...string) { + // we don't need to import anything that is available within the current package + for _, import_ := range imports { + if import_ == ctx.CurrentPackage { + continue + } + + ctx.Imports = append(ctx.Imports, import_) + } + + slices.Sort(ctx.Imports) + ctx.Imports = slices.Compact(ctx.Imports) +} + +func (ctx *Context) indent(text string) string { + if len(strings.TrimSpace(text)) == 0 { + return "" + } + + lines := strings.Split(text, "\n") + for i, line := range lines { + if len(strings.TrimSpace(line)) == 0 { + lines[i] = "" + } else { + lines[i] = ctx.Indent + line + } + } + + return strings.Join(lines, "\n") +} diff --git a/poet/file.go b/poet/file.go new file mode 100644 index 0000000..2703246 --- /dev/null +++ b/poet/file.go @@ -0,0 +1,51 @@ +package poet + +import ( + "strings" +) + +type formattable interface { + name() string + Format(ctx *Context) string +} + +type FileOptions struct { + Comment string +} + +func WithFileComment(comment string) func(*FileOptions) { + return func(o *FileOptions) { + o.Comment = comment + } +} + +func FormatFile(ctx *Context, member formattable, options ...func(*FileOptions)) string { + ctx.CurrentTypeName = member.name() + + opts := &FileOptions{} + for _, o := range options { + o(opts) + } + + var sb strings.Builder + + if opts.Comment != "" { + sb.WriteString(opts.Comment) + if opts.Comment[len(opts.Comment)-1] != '\n' { + sb.WriteString("\n") + } + } + + sb.WriteString("package " + ctx.CurrentPackage + ";\n\n") + + memberString := member.Format(ctx) + for i, imp := range ctx.Imports { + sb.WriteString("import " + imp + ";\n") + if i == len(ctx.Imports)-1 { + sb.WriteString("\n") + } + } + + sb.WriteString(memberString) + return sb.String() +} diff --git a/poet/method.go b/poet/method.go new file mode 100644 index 0000000..30b9997 --- /dev/null +++ b/poet/method.go @@ -0,0 +1,158 @@ +package poet + +import ( + "slices" + "strings" +) + +type MethodParameter struct { + Name string + Type TypeName +} + +func NewMethodParam(name string, typ TypeName) MethodParameter { + return MethodParameter{Name: name, Type: typ} +} + +func (m MethodParameter) Format(ctx *Context) string { + return m.Type.Format(ctx, ExcludeConstraints) + " " + m.Name +} + +type Method struct { + Name string + + Modifiers []Modifier + GenericParameters []TypeName + Parameters []MethodParameter + ReturnType TypeName + Throws []TypeName + + Body Code + + isConstructor bool +} + +func (m Method) Format(ctx *Context) string { + var sb strings.Builder + + sb.WriteString(formatModifiers(m.Modifiers)) + sb.WriteString(" ") + writeGenericParamList(ctx, &sb, m.GenericParameters, true) + if !m.isConstructor { + sb.WriteString(m.ReturnType.Format(ctx, ExcludeConstraints)) + sb.WriteString(" ") + } + + sb.WriteString(m.Name) + + sb.WriteString("(") + for i, param := range m.Parameters { + sb.WriteString(param.Format(ctx)) + if i < len(m.Parameters)-1 { + sb.WriteString(", ") + } + } + + sb.WriteString(")") + if len(m.Throws) > 0 { + sb.WriteString(" throws ") + for i, throw := range m.Throws { + sb.WriteString(throw.Format(ctx, ExcludeConstraints, ExcludeParameters, ExcludeArrayBraces)) + + if i < len(m.Throws)-1 { + sb.WriteString(",") + } + } + } + + // abstract methods cannot have bodies + if slices.Contains(m.Modifiers, ModifierAbstract) { + sb.WriteString(";") + return sb.String() + } + + if code := m.Body.Format(ctx); strings.TrimSpace(code) != "" { + sb.WriteString(" {\n") + sb.WriteString(ctx.indent(code)) + sb.WriteString("}") + } else { + sb.WriteString(" {}") + } + + return sb.String() +} + +type MethodBuilder struct { + method Method +} + +func NewMethodBuilder(name string, returnType TypeName) *MethodBuilder { + return &MethodBuilder{method: Method{Name: name, ReturnType: returnType}} +} + +func (b *MethodBuilder) WithModifiers(modifiers ...Modifier) *MethodBuilder { + b.method.Modifiers = appendModifiers(b.method.Modifiers, modifiers) + return b +} + +func (b *MethodBuilder) WithGenericParameters(parameters ...TypeName) *MethodBuilder { + b.method.GenericParameters = append(b.method.GenericParameters, parameters...) + return b +} + +func (b *MethodBuilder) WithParameters(params ...MethodParameter) *MethodBuilder { + b.method.Parameters = append(b.method.Parameters, params...) + return b +} + +func (b *MethodBuilder) WithThrows(throws ...TypeName) *MethodBuilder { + b.method.Throws = append(b.method.Throws, throws...) + return b +} + +func (b *MethodBuilder) WithCode(code Code) *MethodBuilder { + b.method.Body = code + return b +} + +func (b *MethodBuilder) Build() Method { + b.method.Modifiers = maybeSetPackagePrivate(b.method.Modifiers) + return b.method +} + +type Constructor struct { + Method +} + +type ConstructorBuilder struct { + constructor Constructor +} + +func NewConstructorBuilder() *ConstructorBuilder { + return &ConstructorBuilder{constructor: Constructor{Method{isConstructor: true}}} +} + +func (b *ConstructorBuilder) WithModifiers(modifiers ...Modifier) *ConstructorBuilder { + b.constructor.Modifiers = appendModifiers(b.constructor.Modifiers, modifiers) + return b +} + +func (b *ConstructorBuilder) WithParameters(params ...MethodParameter) *ConstructorBuilder { + b.constructor.Parameters = append(b.constructor.Parameters, params...) + return b +} + +func (b *ConstructorBuilder) WithThrows(throws ...TypeName) *ConstructorBuilder { + b.constructor.Throws = append(b.constructor.Throws, throws...) + return b +} + +func (b *ConstructorBuilder) WithCode(code Code) *ConstructorBuilder { + b.constructor.Body = code + return b +} + +func (b *ConstructorBuilder) Build() Constructor { + b.constructor.Modifiers = maybeSetPackagePrivate(b.constructor.Modifiers) + return b.constructor +} diff --git a/poet/modifier.go b/poet/modifier.go new file mode 100644 index 0000000..dcab351 --- /dev/null +++ b/poet/modifier.go @@ -0,0 +1,61 @@ +package poet + +import ( + "strings" +) + +type Modifier int + +const ( + ModifierPrivate Modifier = iota + ModifierPackagePrivate + ModifierProtected + ModifierPublic + ModifierAbstract + ModifierStatic + ModifierFinal +) + +var accessModifiers = []Modifier{ + ModifierPrivate, + ModifierPackagePrivate, + ModifierProtected, + ModifierPublic, +} + +func formatModifier(modifier Modifier) string { + switch modifier { + case ModifierPrivate: + return "private" + case ModifierPackagePrivate: + return "" + case ModifierProtected: + return "protected" + case ModifierPublic: + return "public" + case ModifierAbstract: + return "abstract" + case ModifierStatic: + return "static" + case ModifierFinal: + return "final" + default: + return "" + } +} + +func formatModifiers(modifiers []Modifier) string { + // assume slice is already sorted and deduplicated + var sb strings.Builder + + for i, mod := range modifiers { + if formatted := formatModifier(mod); formatted != "" { + sb.WriteString(formatted) + if i != len(modifiers)-1 { + sb.WriteString(" ") + } + } + } + + return sb.String() +} diff --git a/poet/name.go b/poet/name.go new file mode 100644 index 0000000..e9f5bea --- /dev/null +++ b/poet/name.go @@ -0,0 +1,190 @@ +package poet + +import ( + "strings" +) + +// TODO - annotation support + +type TypeName struct { + Package string + Name string + + IsBuiltin bool + IsArray bool + + IsParameterized bool + Parameters []TypeName + + IsGeneric bool + Extends []TypeName +} + +func NewClassName(pkg, name string) TypeName { + return TypeName{Package: pkg, Name: name} +} + +func (t TypeName) Array() TypeName { + t.IsArray = true + return t +} + +func NewParameterizedClassName(pkg, name string, parameters ...TypeName) TypeName { + return TypeName{ + Package: pkg, + Name: name, + IsParameterized: true, + Parameters: parameters, + } +} + +func NewGenericParam(name string, extends ...TypeName) TypeName { + return TypeName{ + Package: "", + Name: name, + IsGeneric: true, + Extends: extends, + } +} + +func (t TypeName) Equals(other TypeName) bool { + if t.Package != other.Package || + t.Name != other.Name || + t.IsParameterized != other.IsParameterized || + t.IsGeneric != other.IsGeneric || + t.IsArray != other.IsArray { + return false + } + + if len(t.Parameters) != len(other.Parameters) { + return false + } + for i := range t.Parameters { + if !t.Parameters[i].Equals(other.Parameters[i]) { + return false + } + } + + if len(t.Extends) != len(other.Extends) { + return false + } + // TODO - is order actually important? + for i := range t.Extends { + if !t.Extends[i].Equals(other.Extends[i]) { + return false + } + } + + return true +} + +type FormatOption int + +var ( + ExcludeConstraints FormatOption = 1 << 0 + ExcludeParameters FormatOption = 1 << 1 + ExcludeArrayBraces FormatOption = 1 << 2 +) + +func (opt FormatOption) has(other FormatOption) bool { + return opt&other != 0 +} + +func (t TypeName) Format(ctx *Context, options ...FormatOption) string { + var opts FormatOption + for _, opt := range options { + opts |= opt + } + + var bld strings.Builder + + var typename string + if t.Package != "" { + // check if we need to use the fully qualified type name due to a collision + existing, ok := ctx.Types[t.Name] + if (ok && !t.Equals(existing)) || t.Name == ctx.CurrentTypeName { + bld.WriteString(t.Package) + bld.WriteString(".") + typename = t.Package + "." + t.Name + } else { + if !t.IsBuiltin { + ctx.Import(t.Package) + } + typename = t.Name + } + } + + bld.WriteString(t.Name) + if t.IsGeneric && len(t.Extends) > 0 && !opts.has(ExcludeConstraints) { + bld.WriteString(" extends ") + for i, extend := range t.Extends { + // TODO - can a generic parameter have its own generic constraint? + bld.WriteString(extend.Format(ctx, opts)) + if i < len(t.Extends)-1 { + bld.WriteString(" & ") + } + } + } else if t.IsParameterized && !opts.has(ExcludeParameters) { + bld.WriteString("<") + for i, param := range t.Parameters { + bld.WriteString(param.Format(ctx, opts)) + if i < len(t.Parameters)-1 { + bld.WriteString(", ") + } + } + bld.WriteString(">") + } + + if !opts.has(ExcludeArrayBraces) && t.IsArray && !(t.IsGeneric && !opts.has(ExcludeConstraints)) { + bld.WriteString("[]") + } + + // generic type names are not necessarily unique within a file + if !t.IsGeneric { + ctx.Types[typename] = t + } // ctx.GenericTypes = append(ctx.GenericTypes, t) + + return bld.String() +} + +var ( + Bool = TypeName{Name: "boolean", IsBuiltin: true} + BoolBoxed = TypeName{Package: "java.lang", Name: "Boolean", IsBuiltin: true} + Byte = TypeName{Name: "byte", IsBuiltin: true} + ByteBoxed = TypeName{Package: "java.lang", Name: "Byte", IsBuiltin: true} + Char = TypeName{Name: "char", IsBuiltin: true} + CharBoxed = TypeName{Package: "java.lang", Name: "Character", IsBuiltin: true} + Double = TypeName{Name: "double", IsBuiltin: true} + DoubleBoxed = TypeName{Package: "java.lang", Name: "Double", IsBuiltin: true} + Float = TypeName{Name: "float", IsBuiltin: true} + FloatBoxed = TypeName{Package: "java.lang", Name: "Float", IsBuiltin: true} + Int = TypeName{Name: "int", IsBuiltin: true} + IntBoxed = TypeName{Package: "java.lang", Name: "Integer", IsBuiltin: true} + Long = TypeName{Name: "long", IsBuiltin: true} + LongBoxed = TypeName{Package: "java.lang", Name: "Long", IsBuiltin: true} + Object = TypeName{Package: "java.lang", Name: "Object", IsBuiltin: true} + Short = TypeName{Name: "short", IsBuiltin: true} + ShortBoxed = TypeName{Package: "java.lang", Name: "Short", IsBuiltin: true} + String = TypeName{Package: "java.lang", Name: "String", IsBuiltin: true} + Void = TypeName{Name: "void", IsBuiltin: true} + VoidBoxed = TypeName{Package: "java.lang", Name: "Void", IsBuiltin: true} + Wildcard = TypeName{Name: "?", IsBuiltin: true} +) + +func newSingleParameterType(pkg, name string) func(TypeName) TypeName { + return func(t TypeName) TypeName { + return NewParameterizedClassName(pkg, name, t) + } +} + +func newTwoParameterType(pkg, name string) func(TypeName, TypeName) TypeName { + return func(t1 TypeName, t2 TypeName) TypeName { + return NewParameterizedClassName(pkg, name, t1, t2) + } +} + +var ( + ListOf = newSingleParameterType("java.util", "List") + MapOf = newTwoParameterType("java.util", "Map") + SetOf = newSingleParameterType("java.util", "Set") +) diff --git a/poet/name_test.go b/poet/name_test.go new file mode 100644 index 0000000..60726e5 --- /dev/null +++ b/poet/name_test.go @@ -0,0 +1,199 @@ +package poet + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFormatClassName_InCurrentPackage(t *testing.T) { + assert := assert.New(t) + + name := NewClassName("io.github.tandemdude", "Foo") + + ctx := NewContext("io.github.tandemdude") + assert.Equal("Foo", name.Format(ctx)) + assert.Empty(ctx.Imports) +} + +func TestFormatClassName_NotInCurrentPackage(t *testing.T) { + assert := assert.New(t) + + name := NewClassName("io.github.tandemdude", "Foo") + + ctx := NewContext("com.example") + assert.Equal("Foo", name.Format(ctx)) + assert.Len(ctx.Imports, 1) + assert.Equal("io.github.tandemdude", ctx.Imports[0]) +} + +func TestFormatClassName_ConflictingNames(t *testing.T) { + assert := assert.New(t) + + name := NewClassName("io.github.tandemdude", "Foo") + name2 := NewClassName("io.github.davfsa", "Foo") + + ctx := NewContext("com.example") + assert.Equal("Foo", name.Format(ctx)) + assert.Equal("io.github.davfsa.Foo", name2.Format(ctx)) + assert.Len(ctx.Imports, 1) + assert.Equal("io.github.tandemdude", ctx.Imports[0]) +} + +func TestFormatParameterizedClassName_NoParameters(t *testing.T) { + assert := assert.New(t) + + name := NewParameterizedClassName("java.util", "List") + + ctx := NewContext("io.github.tandemdude") + assert.Equal("List<>", name.Format(ctx)) + assert.Len(ctx.Imports, 1) + assert.Equal("java.util", ctx.Imports[0]) +} + +func TestFormatParameterizedClassName_SingleParameter(t *testing.T) { + assert := assert.New(t) + + name := NewParameterizedClassName("java.util", "List", NewClassName("io.github.tandemdude", "Foo")) + + ctx := NewContext("io.github.tandemdude") + assert.Equal("List", name.Format(ctx)) + assert.Len(ctx.Imports, 1) + assert.Equal("java.util", ctx.Imports[0]) +} + +func TestFormatParameterizedClassName_MultipleParameters(t *testing.T) { + assert := assert.New(t) + + name := NewParameterizedClassName( + "java.util", "Map", + NewClassName("io.github.tandemdude", "Foo"), + NewClassName("io.github.tandemdude", "Bar"), + ) + + ctx := NewContext("io.github.tandemdude") + assert.Equal("Map", name.Format(ctx)) + assert.Len(ctx.Imports, 1) + assert.Equal("java.util", ctx.Imports[0]) +} + +func TestFormatGenericParam_NoConstraints(t *testing.T) { + assert := assert.New(t) + + name := NewGenericParam("T") + + ctx := NewContext("io.github.tandemdude") + assert.Equal("T", name.Format(ctx)) + assert.Empty(ctx.Imports) +} + +func TestFormatGenericParam_OneConstraint(t *testing.T) { + assert := assert.New(t) + + name := NewGenericParam("T", NewClassName("io.github.tandemdude", "Foo")) + + ctx := NewContext("io.github.tandemdude") + assert.Equal("T extends Foo", name.Format(ctx)) + assert.Empty(ctx.Imports) +} + +func TestFormatGenericParam_MultipleConstraints(t *testing.T) { + assert := assert.New(t) + + name := NewGenericParam( + "T", + NewClassName("io.github.tandemdude", "Foo"), + NewClassName("io.github.tandemdude", "Bar"), + ) + + ctx := NewContext("io.github.tandemdude") + assert.Equal("T extends Foo & Bar", name.Format(ctx)) + assert.Empty(ctx.Imports) +} + +func TestFormatClassName_IsArray(t *testing.T) { + assert := assert.New(t) + + name := NewClassName("io.github.tandemdude", "Foo").Array() + + ctx := NewContext("io.github.tandemdude") + assert.Equal("Foo[]", name.Format(ctx)) + assert.Empty(ctx.Imports) +} + +func TestFormatParameterizedClassName_IsArray(t *testing.T) { + assert := assert.New(t) + + name := NewParameterizedClassName("io.github.tandemdude", "Foo", String).Array() + + ctx := NewContext("io.github.tandemdude") + assert.Equal("Foo[]", name.Format(ctx)) + assert.Empty(ctx.Imports) +} + +func TestFormatGenericParam_IsArray(t *testing.T) { + assert := assert.New(t) + + name := NewGenericParam("T").Array() + + ctx := NewContext("io.github.tandemdude") + assert.Equal("T[]", name.Format(ctx, ExcludeConstraints)) + assert.Empty(ctx.Imports) +} + +func TestClassName_SimpleEquals(t *testing.T) { + assert := assert.New(t) + + name := NewClassName("", "Foo") + pName := NewParameterizedClassName("", "Foo") + gName := NewGenericParam("Foo") + aName := name.Array() + + // the same instances are equal to each other + assert.Equal(name, name) + assert.Equal(pName, pName) + assert.Equal(gName, gName) + assert.Equal(aName, aName) + + // different types of class names are not equal + assert.NotEqual(name, pName) + assert.NotEqual(name, gName) + assert.NotEqual(name, aName) + + assert.NotEqual(pName, gName) + assert.NotEqual(pName, aName) + + assert.NotEqual(gName, aName) +} + +func TestParameterizedClassName_Equals(t *testing.T) { + assert := assert.New(t) + + name := NewParameterizedClassName("", "Foo") + name2 := NewParameterizedClassName("", "Foo", String) + + // types with different parameter counts are not equal + assert.NotEqual(name, name2) + + name = NewParameterizedClassName("", "Foo", String) + name2 = NewParameterizedClassName("", "Foo", IntBoxed) + + // types with the same parameter count, but different parameters are not equal + assert.NotEqual(name, name2) +} + +func TestGenericParam_Equals(t *testing.T) { + assert := assert.New(t) + + name := NewGenericParam("T") + name2 := NewGenericParam("T", String) + + // generic parameters with different constraint counts are not equal + assert.NotEqual(name, name2) + + name = NewGenericParam("T", String) + name2 = NewGenericParam("T", IntBoxed) + + // generic parameters with the same constraint counts, but different constraints are not equal + assert.NotEqual(name, name2) +} diff --git a/poet/type.go b/poet/type.go new file mode 100644 index 0000000..0ba1992 --- /dev/null +++ b/poet/type.go @@ -0,0 +1,291 @@ +package poet + +import ( + "fmt" + "strings" +) + +// TODO - annotation support + +type ClassField struct { + Name string + Type TypeName + Modifiers []Modifier +} + +type Class struct { + Name string + + Modifiers []Modifier + GenericParameters []TypeName + Constructor *Constructor + Fields []ClassField + Methods []Method +} + +func (c Class) name() string { + return c.Name +} + +func (c Class) Format(ctx *Context) string { + var sb strings.Builder + + sb.WriteString(formatModifiers(c.Modifiers)) + if sb.Len() > 0 { + sb.WriteString(" ") + } + + sb.WriteString("class ") + sb.WriteString(c.Name) + writeGenericParamList(ctx, &sb, c.GenericParameters, false) + sb.WriteString(" {\n") + + for i, field := range c.Fields { + sb.WriteString(ctx.indent(fmt.Sprintf( + "%s %s %s;\n", + formatModifiers(field.Modifiers), + field.Type.Format(ctx, ExcludeConstraints), + field.Name, + ))) + + if i == len(c.Fields)-1 { + sb.WriteString("\n") + } + } + + if c.Constructor != nil { + c.Constructor.Name = c.Name + sb.WriteString(ctx.indent(c.Constructor.Format(ctx))) + sb.WriteString("\n") + } + + if len(c.Methods) > 0 { + sb.WriteString("\n") + } + + for i, method := range c.Methods { + sb.WriteString(ctx.indent(method.Format(ctx))) + sb.WriteString("\n") + if i != len(c.Methods)-1 { + sb.WriteString("\n") + } + } + + sb.WriteString("}") + return sb.String() +} + +type ClassBuilder struct { + class Class +} + +func NewClassBuilder(name string) *ClassBuilder { + return &ClassBuilder{class: Class{Name: name}} +} + +func (c *ClassBuilder) WithModifiers(modifiers ...Modifier) *ClassBuilder { + c.class.Modifiers = appendModifiers(c.class.Modifiers, modifiers) + return c +} + +func (c *ClassBuilder) WithGenericParameters(parameters ...TypeName) *ClassBuilder { + c.class.GenericParameters = append(c.class.GenericParameters, parameters...) + return c +} + +func (c *ClassBuilder) WithConstructor(constructor Constructor) *ClassBuilder { + c.class.Constructor = &constructor + return c +} + +func (c *ClassBuilder) WithFields(fields ...ClassField) *ClassBuilder { + c.class.Fields = append(c.class.Fields, fields...) + return c +} + +func (c *ClassBuilder) WithMethods(methods ...Method) *ClassBuilder { + c.class.Methods = append(c.class.Methods, methods...) + return c +} + +func (c *ClassBuilder) Build() Class { + c.class.Modifiers = maybeSetPackagePrivate(c.class.Modifiers) + return c.class +} + +type EnumValue struct { + Name string + // for now, only support string enums given that is the only type supported by the databases + Value string +} + +func NewEnumValue(name string, value string) EnumValue { + return EnumValue{Name: name, Value: value} +} + +type Enum struct { + Name string + + Modifiers []Modifier + Values []EnumValue + Methods []Method +} + +func (e Enum) name() string { + return e.Name +} + +func (e Enum) Format(ctx *Context) string { + var sb strings.Builder + + sb.WriteString(formatModifiers(e.Modifiers)) + if sb.Len() > 0 { + sb.WriteString(" ") + } + + sb.WriteString("enum ") + sb.WriteString(e.Name) + sb.WriteString(" {\n") + + for i, v := range e.Values { + sb.WriteString(ctx.indent(fmt.Sprintf("%s(\"%s\")", v.Name, v.Value))) + + if i != len(e.Values)-1 { + sb.WriteString(",\n") + } else { + sb.WriteString(";\n\n") + } + } + + sb.WriteString(ctx.indent("private final String value;\n\n")) + sb.WriteString(ctx.indent(e.Name)) + sb.WriteString("(final String value) { this.value = value; }\n") + + if len(e.Methods) > 0 { + sb.WriteString("\n") + } + + for i, method := range e.Methods { + sb.WriteString(ctx.indent(method.Format(ctx))) + sb.WriteString("\n") + if i != len(e.Methods)-1 { + sb.WriteString("\n") + } + } + + sb.WriteString("}") + + return sb.String() +} + +type EnumBuilder struct { + enum Enum +} + +func NewEnumBuilder(name string) *EnumBuilder { + return &EnumBuilder{enum: Enum{Name: name}} +} + +func (b *EnumBuilder) WithModifiers(modifiers ...Modifier) *EnumBuilder { + b.enum.Modifiers = appendModifiers(b.enum.Modifiers, modifiers) + return b +} + +func (b *EnumBuilder) WithValue(name string, value string) *EnumBuilder { + b.enum.Values = append(b.enum.Values, EnumValue{Name: name, Value: value}) + return b +} + +func (b *EnumBuilder) WithValues(values ...EnumValue) *EnumBuilder { + b.enum.Values = append(b.enum.Values, values...) + return b +} + +func (b *EnumBuilder) WithMethods(methods ...Method) *EnumBuilder { + b.enum.Methods = append(b.enum.Methods, methods...) + return b +} + +func (b *EnumBuilder) Build() Enum { + b.enum.Modifiers = maybeSetPackagePrivate(b.enum.Modifiers) + return b.enum +} + +type Record struct { + Name string + + Modifiers []Modifier + Parameters []MethodParameter + Methods []Method +} + +func (r Record) name() string { + return r.Name +} + +func (r Record) Format(ctx *Context) string { + var sb strings.Builder + + sb.WriteString(formatModifiers(r.Modifiers)) + if sb.Len() > 0 { + sb.WriteString(" ") + } + + // TODO - generic parameter support? + + sb.WriteString("record ") + sb.WriteString(r.Name) + sb.WriteString("(") + + for i, param := range r.Parameters { + sb.WriteString(param.Format(ctx)) + if i != len(r.Parameters)-1 { + sb.WriteString(", ") + } + } + sb.WriteString(") ") + + if len(r.Methods) == 0 { + sb.WriteString("{}") + return sb.String() + } + + for i, method := range r.Methods { + sb.WriteString(ctx.indent(method.Format(ctx))) + sb.WriteString("\n") + if i != len(r.Methods)-1 { + sb.WriteString("\n") + } + } + sb.WriteString("}") + + return sb.String() +} + +type RecordBuilder struct { + record Record +} + +func NewRecordBuilder(name string) *RecordBuilder { + return &RecordBuilder{record: Record{Name: name}} +} + +func (b *RecordBuilder) WithModifiers(modifiers ...Modifier) *RecordBuilder { + b.record.Modifiers = appendModifiers(b.record.Modifiers, modifiers) + return b +} + +func (b *RecordBuilder) WithParameters(params ...MethodParameter) *RecordBuilder { + b.record.Parameters = append(b.record.Parameters, params...) + return b +} + +func (b *RecordBuilder) WithMethods(methods ...Method) *RecordBuilder { + b.record.Methods = append(b.record.Methods, methods...) + return b +} + +func (b *RecordBuilder) Build() Record { + b.record.Modifiers = maybeSetPackagePrivate(b.record.Modifiers) + return b.record +} diff --git a/poet/utils.go b/poet/utils.go new file mode 100644 index 0000000..cadb040 --- /dev/null +++ b/poet/utils.go @@ -0,0 +1,40 @@ +package poet + +import ( + "slices" + "strings" +) + +func appendModifiers(initial []Modifier, new []Modifier) []Modifier { + initial = append(initial, new...) + slices.Sort(initial) + return slices.Compact(initial) +} + +func maybeSetPackagePrivate(modifiers []Modifier) []Modifier { + if !slices.ContainsFunc(modifiers, func(m Modifier) bool { + return slices.Contains(accessModifiers, m) + }) { + modifiers = append(modifiers, ModifierPackagePrivate) + slices.Sort(modifiers) + } + return modifiers +} + +func writeGenericParamList(ctx *Context, sb *strings.Builder, params []TypeName, includeTrailingSpace bool) { + if len(params) == 0 { + return + } + + sb.WriteString("<") + for i, param := range params { + sb.WriteString(param.Format(ctx, ExcludeConstraints)) + if i < len(params)-1 { + sb.WriteString(", ") + } + } + sb.WriteString(">") + if includeTrailingSpace { + sb.WriteString(" ") + } +}