Skip to content

Commit 719379c

Browse files
committed
[JExtract] Static 'call' method in binding descriptor classes
* Add a static call method to each binding descriptor class to handle the actual downcall. * Refactor wrapper methods to delegate to the binding descriptor's call method. * Clearly separate responsibilities: each binding descriptor class now encapsulates the complete lowered Cdecl thunk, while wrapper methods focus on Java-to-Cdecl type conversion.
1 parent 5d58ecc commit 719379c

7 files changed

+251
-187
lines changed

Sources/JExtractSwift/Swift2JavaTranslator+JavaBindingsPrinting.swift

Lines changed: 59 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,27 @@ extension Swift2JavaTranslator {
2424
printJavaBindingDescriptorClass(&printer, decl)
2525

2626
// Render the "make the downcall" functions.
27-
printFuncDowncallMethod(&printer, decl)
27+
printJavaBindingWrapperMethod(&printer, decl)
2828
}
2929

3030
/// Print FFM Java binding descriptors for the imported Swift API.
31-
func printJavaBindingDescriptorClass(
31+
package func printJavaBindingDescriptorClass(
3232
_ printer: inout CodePrinter,
3333
_ decl: ImportedFunc
3434
) {
3535
let thunkName = thunkNameRegistry.functionThunkName(decl: decl)
3636
let cFunc = decl.cFunctionDecl(cName: thunkName)
3737

38-
printer.printBraceBlock("private static class \(cFunc.name)") { printer in
38+
printer.printBraceBlock(
39+
"""
40+
/**
41+
* {@snippet lang=c :
42+
* \(cFunc.description)
43+
* }
44+
*/
45+
private static class \(cFunc.name)
46+
"""
47+
) { printer in
3948
printFunctionDescriptorValue(&printer, cFunc)
4049
printer.print(
4150
"""
@@ -44,11 +53,12 @@ extension Swift2JavaTranslator {
4453
public static final MethodHandle HANDLE = Linker.nativeLinker().downcallHandle(ADDR, DESC);
4554
"""
4655
)
56+
printJavaBindingDowncallMethod(&printer, cFunc)
4757
}
4858
}
4959

5060
/// Print the 'FunctionDescriptor' of the lowered cdecl thunk.
51-
public func printFunctionDescriptorValue(
61+
func printFunctionDescriptorValue(
5262
_ printer: inout CodePrinter,
5363
_ cFunc: CFunction
5464
) {
@@ -74,9 +84,42 @@ extension Swift2JavaTranslator {
7484
printer.print(");")
7585
}
7686

87+
func printJavaBindingDowncallMethod(
88+
_ printer: inout CodePrinter,
89+
_ cFunc: CFunction
90+
) {
91+
let returnTy = cFunc.resultType.javaType
92+
let maybeReturn = cFunc.resultType.isVoid ? "" : "return (\(returnTy)) "
93+
94+
var params: [String] = []
95+
var args: [String] = []
96+
for param in cFunc.parameters {
97+
// ! unwrapping because cdecl lowering guarantees the parameter named.
98+
params.append("\(param.type.javaType) \(param.name!)")
99+
args.append(param.name!)
100+
}
101+
let paramsStr = params.joined(separator: ", ")
102+
let argsStr = args.joined(separator: ", ")
103+
104+
printer.print(
105+
"""
106+
public static \(returnTy) call(\(paramsStr)) {
107+
try {
108+
if (SwiftKit.TRACE_DOWNCALLS) {
109+
SwiftKit.traceDowncall(\(argsStr));
110+
}
111+
\(maybeReturn)HANDLE.invokeExact(\(argsStr));
112+
} catch (Throwable ex$) {
113+
throw new AssertionError("should not reach here", ex$);
114+
}
115+
}
116+
"""
117+
)
118+
}
119+
77120
/// Print the calling body that forwards all the parameters to the `methodName`,
78121
/// with adding `SwiftArena.ofAuto()` at the end.
79-
public func printFuncDowncallMethod(
122+
public func printJavaBindingWrapperMethod(
80123
_ printer: inout CodePrinter,
81124
_ decl: ImportedFunc) {
82125
let methodName: String = switch decl.kind {
@@ -130,19 +173,11 @@ extension Swift2JavaTranslator {
130173
_ printer: inout CodePrinter,
131174
_ decl: ImportedFunc
132175
) {
133-
//=== Part 1: MethodHandle
134-
let descriptorClassIdentifier = thunkNameRegistry.functionThunkName(decl: decl)
135-
printer.print(
136-
"var mh$ = \(descriptorClassIdentifier).HANDLE;"
137-
)
138-
139-
let tryHead = if decl.translatedSignature.requiresTemporaryArena {
140-
"try(var arena$ = Arena.ofConfined()) {"
141-
} else {
142-
"try {"
176+
//=== Part 1: prepare temporary arena if needed.
177+
if decl.translatedSignature.requiresTemporaryArena {
178+
printer.print("try(var arena$ = Arena.ofConfined()) {")
179+
printer.indent();
143180
}
144-
printer.print(tryHead);
145-
printer.indent();
146181

147182
//=== Part 2: prepare all arguments.
148183
var downCallArguments: [String] = []
@@ -151,15 +186,7 @@ extension Swift2JavaTranslator {
151186
for (i, parameter) in decl.translatedSignature.parameters.enumerated() {
152187
let original = decl.swiftSignature.parameters[i]
153188
let parameterName = original.parameterName ?? "_\(i)"
154-
let converted = parameter.conversion.render(&printer, parameterName)
155-
let lowered: String
156-
if parameter.conversion.isTrivial {
157-
lowered = converted
158-
} else {
159-
// Store the conversion to a temporary variable.
160-
lowered = "\(parameterName)$"
161-
printer.print("var \(lowered) = \(converted);")
162-
}
189+
let lowered = parameter.conversion.render(&printer, parameterName)
163190
downCallArguments.append(lowered)
164191
}
165192

@@ -191,14 +218,8 @@ extension Swift2JavaTranslator {
191218
}
192219

193220
//=== Part 3: Downcall.
194-
printer.print(
195-
"""
196-
if (SwiftKit.TRACE_DOWNCALLS) {
197-
SwiftKit.traceDowncall(\(downCallArguments.joined(separator: ", ")));
198-
}
199-
"""
200-
)
201-
let downCall = "mh$.invokeExact(\(downCallArguments.joined(separator: ", ")))"
221+
let thunkName = thunkNameRegistry.functionThunkName(decl: decl)
222+
let downCall = "\(thunkName).call(\(downCallArguments.joined(separator: ", ")))"
202223

203224
//=== Part 4: Convert the return value.
204225
if decl.translatedSignature.result.javaResultType == .void {
@@ -221,14 +242,10 @@ extension Swift2JavaTranslator {
221242
}
222243
}
223244

224-
printer.outdent()
225-
printer.print(
226-
"""
227-
} catch (Throwable ex$) {
228-
throw new AssertionError("should not reach here", ex$);
229-
}
230-
"""
231-
)
245+
if decl.translatedSignature.requiresTemporaryArena {
246+
printer.outdent()
247+
printer.print("}")
248+
}
232249
}
233250

234251
func renderMemoryLayoutValue(for javaType: JavaType) -> String {

Sources/JExtractSwift/Swift2JavaTranslator+JavaTranslation.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ struct JavaTranslation {
270270
return TranslatedResult(
271271
javaResultType: javaType,
272272
outParameters: [],
273-
conversion: .cast(javaType)
273+
conversion: .pass
274274
)
275275
}
276276

Tests/JExtractSwiftTests/FuncCallbackImportTests.swift

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ final class FuncCallbackImportTests {
4545
let funcDecl = st.importedGlobalFuncs.first { $0.name == "callMe" }!
4646

4747
let output = CodePrinter.toString { printer in
48-
st.printFuncDowncallMethod(&printer, funcDecl)
48+
st.printJavaBindingWrapperMethod(&printer, funcDecl)
4949
}
5050

5151
assertOutput(
@@ -59,15 +59,8 @@ final class FuncCallbackImportTests {
5959
* }
6060
*/
6161
public static void callMe(java.lang.Runnable callback) {
62-
var mh$ = swiftjava___FakeModule_callMe_callback.HANDLE;
6362
try(var arena$ = Arena.ofConfined()) {
64-
var callback$ = SwiftKit.toUpcallStub(callback, arena$);
65-
if (SwiftKit.TRACE_DOWNCALLS) {
66-
SwiftKit.traceDowncall(callback$);
67-
}
68-
mh$.invokeExact(callback$);
69-
} catch (Throwable ex$) {
70-
throw new AssertionError("should not reach here", ex$);
63+
swiftjava___FakeModule_callMe_callback.call(SwiftKit.toUpcallStub(callback, arena$))
7164
}
7265
}
7366
"""

Tests/JExtractSwiftTests/FunctionDescriptorImportTests.swift

Lines changed: 121 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,29 @@ final class FunctionDescriptorTests {
5151
output,
5252
expected:
5353
"""
54-
public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid(
55-
/* i: */SwiftValueLayout.SWIFT_INT
56-
);
54+
/**
55+
* {@snippet lang=c :
56+
* void swiftjava_SwiftModule_globalTakeInt_i(ptrdiff_t i)
57+
* }
58+
*/
59+
private static class swiftjava_SwiftModule_globalTakeInt_i {
60+
public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid(
61+
/* i: */SwiftValueLayout.SWIFT_INT
62+
);
63+
public static final MemorySegment ADDR =
64+
SwiftModule.findOrThrow("swiftjava_SwiftModule_globalTakeInt_i");
65+
public static final MethodHandle HANDLE = Linker.nativeLinker().downcallHandle(ADDR, DESC);
66+
public static void call(long i) {
67+
try {
68+
if (SwiftKit.TRACE_DOWNCALLS) {
69+
SwiftKit.traceDowncall(i);
70+
}
71+
HANDLE.invokeExact(i);
72+
} catch (Throwable ex$) {
73+
throw new AssertionError("should not reach here", ex$);
74+
}
75+
}
76+
}
5777
"""
5878
)
5979
}
@@ -66,10 +86,30 @@ final class FunctionDescriptorTests {
6686
output,
6787
expected:
6888
"""
69-
public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid(
70-
/* l: */SwiftValueLayout.SWIFT_INT64,
71-
/* i32: */SwiftValueLayout.SWIFT_INT32
72-
);
89+
/**
90+
* {@snippet lang=c :
91+
* void swiftjava_SwiftModule_globalTakeLongInt_l_i32(int64_t l, int32_t i32)
92+
* }
93+
*/
94+
private static class swiftjava_SwiftModule_globalTakeLongInt_l_i32 {
95+
public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid(
96+
/* l: */SwiftValueLayout.SWIFT_INT64,
97+
/* i32: */SwiftValueLayout.SWIFT_INT32
98+
);
99+
public static final MemorySegment ADDR =
100+
SwiftModule.findOrThrow("swiftjava_SwiftModule_globalTakeLongInt_l_i32");
101+
public static final MethodHandle HANDLE = Linker.nativeLinker().downcallHandle(ADDR, DESC);
102+
public static void call(long l, int i32) {
103+
try {
104+
if (SwiftKit.TRACE_DOWNCALLS) {
105+
SwiftKit.traceDowncall(l, i32);
106+
}
107+
HANDLE.invokeExact(l, i32);
108+
} catch (Throwable ex$) {
109+
throw new AssertionError("should not reach here", ex$);
110+
}
111+
}
112+
}
73113
"""
74114
)
75115
}
@@ -82,10 +122,30 @@ final class FunctionDescriptorTests {
82122
output,
83123
expected:
84124
"""
85-
public static final FunctionDescriptor DESC = FunctionDescriptor.of(
86-
/* -> */SwiftValueLayout.SWIFT_INT,
87-
/* i: */SwiftValueLayout.SWIFT_INT
88-
);
125+
/**
126+
* {@snippet lang=c :
127+
* ptrdiff_t swiftjava_SwiftModule_echoInt_i(ptrdiff_t i)
128+
* }
129+
*/
130+
private static class swiftjava_SwiftModule_echoInt_i {
131+
public static final FunctionDescriptor DESC = FunctionDescriptor.of(
132+
/* -> */SwiftValueLayout.SWIFT_INT,
133+
/* i: */SwiftValueLayout.SWIFT_INT
134+
);
135+
public static final MemorySegment ADDR =
136+
SwiftModule.findOrThrow("swiftjava_SwiftModule_echoInt_i");
137+
public static final MethodHandle HANDLE = Linker.nativeLinker().downcallHandle(ADDR, DESC);
138+
public static long call(long i) {
139+
try {
140+
if (SwiftKit.TRACE_DOWNCALLS) {
141+
SwiftKit.traceDowncall(i);
142+
}
143+
return (long) HANDLE.invokeExact(i);
144+
} catch (Throwable ex$) {
145+
throw new AssertionError("should not reach here", ex$);
146+
}
147+
}
148+
}
89149
"""
90150
)
91151
}
@@ -98,10 +158,30 @@ final class FunctionDescriptorTests {
98158
output,
99159
expected:
100160
"""
101-
public static final FunctionDescriptor DESC = FunctionDescriptor.of(
102-
/* -> */SwiftValueLayout.SWIFT_INT32,
103-
/* self: */SwiftValueLayout.SWIFT_POINTER
104-
);
161+
/**
162+
* {@snippet lang=c :
163+
* int32_t swiftjava_SwiftModule_MySwiftClass_counter$get(const void *self)
164+
* }
165+
*/
166+
private static class swiftjava_SwiftModule_MySwiftClass_counter$get {
167+
public static final FunctionDescriptor DESC = FunctionDescriptor.of(
168+
/* -> */SwiftValueLayout.SWIFT_INT32,
169+
/* self: */SwiftValueLayout.SWIFT_POINTER
170+
);
171+
public static final MemorySegment ADDR =
172+
SwiftModule.findOrThrow("swiftjava_SwiftModule_MySwiftClass_counter$get");
173+
public static final MethodHandle HANDLE = Linker.nativeLinker().downcallHandle(ADDR, DESC);
174+
public static int call(java.lang.foreign.MemorySegment self) {
175+
try {
176+
if (SwiftKit.TRACE_DOWNCALLS) {
177+
SwiftKit.traceDowncall(self);
178+
}
179+
return (int) HANDLE.invokeExact(self);
180+
} catch (Throwable ex$) {
181+
throw new AssertionError("should not reach here", ex$);
182+
}
183+
}
184+
}
105185
"""
106186
)
107187
}
@@ -113,10 +193,30 @@ final class FunctionDescriptorTests {
113193
output,
114194
expected:
115195
"""
116-
public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid(
117-
/* newValue: */SwiftValueLayout.SWIFT_INT32,
118-
/* self: */SwiftValueLayout.SWIFT_POINTER
119-
);
196+
/**
197+
* {@snippet lang=c :
198+
* void swiftjava_SwiftModule_MySwiftClass_counter$set(int32_t newValue, const void *self)
199+
* }
200+
*/
201+
private static class swiftjava_SwiftModule_MySwiftClass_counter$set {
202+
public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid(
203+
/* newValue: */SwiftValueLayout.SWIFT_INT32,
204+
/* self: */SwiftValueLayout.SWIFT_POINTER
205+
);
206+
public static final MemorySegment ADDR =
207+
SwiftModule.findOrThrow("swiftjava_SwiftModule_MySwiftClass_counter$set");
208+
public static final MethodHandle HANDLE = Linker.nativeLinker().downcallHandle(ADDR, DESC);
209+
public static void call(int newValue, java.lang.foreign.MemorySegment self) {
210+
try {
211+
if (SwiftKit.TRACE_DOWNCALLS) {
212+
SwiftKit.traceDowncall(newValue, self);
213+
}
214+
HANDLE.invokeExact(newValue, self);
215+
} catch (Throwable ex$) {
216+
throw new AssertionError("should not reach here", ex$);
217+
}
218+
}
219+
}
120220
"""
121221
)
122222
}
@@ -145,10 +245,8 @@ extension FunctionDescriptorTests {
145245
$0.name == methodIdentifier
146246
}!
147247

148-
let thunkName = st.thunkNameRegistry.functionThunkName(decl: funcDecl)
149-
let cFunc = funcDecl.cFunctionDecl(cName: thunkName)
150248
let output = CodePrinter.toString { printer in
151-
st.printFunctionDescriptorValue(&printer, cFunc)
249+
st.printJavaBindingDescriptorClass(&printer, funcDecl)
152250
}
153251

154252
try body(output)
@@ -180,10 +278,8 @@ extension FunctionDescriptorTests {
180278
fatalError("Cannot find descriptor of: \(identifier)")
181279
}
182280

183-
let thunkName = st.thunkNameRegistry.functionThunkName(decl: accessorDecl)
184-
let cFunc = accessorDecl.cFunctionDecl(cName: thunkName)
185281
let getOutput = CodePrinter.toString { printer in
186-
st.printFunctionDescriptorValue(&printer, cFunc)
282+
st.printJavaBindingDescriptorClass(&printer, accessorDecl)
187283
}
188284

189285
try body(getOutput)

0 commit comments

Comments
 (0)