@@ -6,7 +6,7 @@ import SwiftSyntaxMacros
6
6
7
7
protocol ParamInfo : CustomStringConvertible {
8
8
var description : String { get }
9
- var original : ExprSyntax { get }
9
+ var original : SyntaxProtocol { get }
10
10
var pointerIndex : Int { get }
11
11
var nonescaping : Bool { get set }
12
12
@@ -16,12 +16,31 @@ protocol ParamInfo: CustomStringConvertible {
16
16
) -> BoundsCheckedThunkBuilder
17
17
}
18
18
19
+ struct CxxSpan : ParamInfo {
20
+ var pointerIndex : Int
21
+ var nonescaping : Bool
22
+ var original : SyntaxProtocol
23
+ var typeMappings : [ String : String ]
24
+
25
+ var description : String {
26
+ return " std::span(pointer: \( pointerIndex) , nonescaping: \( nonescaping) ) "
27
+ }
28
+
29
+ func getBoundsCheckedThunkBuilder(
30
+ _ base: BoundsCheckedThunkBuilder , _ funcDecl: FunctionDeclSyntax ,
31
+ _ variant: Variant
32
+ ) -> BoundsCheckedThunkBuilder {
33
+ CxxSpanThunkBuilder ( base: base, index: pointerIndex - 1 , signature: funcDecl. signature,
34
+ typeMappings: typeMappings, node: original)
35
+ }
36
+ }
37
+
19
38
struct CountedBy : ParamInfo {
20
39
var pointerIndex : Int
21
40
var count : ExprSyntax
22
41
var sizedBy : Bool
23
42
var nonescaping : Bool
24
- var original : ExprSyntax
43
+ var original : SyntaxProtocol
25
44
26
45
var description : String {
27
46
if sizedBy {
@@ -43,11 +62,12 @@ struct CountedBy: ParamInfo {
43
62
nonescaping: nonescaping, isSizedBy: sizedBy)
44
63
}
45
64
}
65
+
46
66
struct EndedBy : ParamInfo {
47
67
var pointerIndex : Int
48
68
var endIndex : Int
49
69
var nonescaping : Bool
50
- var original : ExprSyntax
70
+ var original : SyntaxProtocol
51
71
52
72
var description : String {
53
73
return " .endedBy(start: \( pointerIndex) , end: \( endIndex) , nonescaping: \( nonescaping) ) "
@@ -196,6 +216,7 @@ func getParam(_ signature: FunctionSignatureSyntax, _ paramIndex: Int) -> Functi
196
216
return params [ params. startIndex]
197
217
}
198
218
}
219
+
199
220
func getParam( _ funcDecl: FunctionDeclSyntax , _ paramIndex: Int ) -> FunctionParameterSyntax {
200
221
return getParam ( funcDecl. signature, paramIndex)
201
222
}
@@ -256,6 +277,43 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
256
277
}
257
278
}
258
279
280
+ struct CxxSpanThunkBuilder : BoundsCheckedThunkBuilder {
281
+ public let base : BoundsCheckedThunkBuilder
282
+ public let index : Int
283
+ public let signature : FunctionSignatureSyntax
284
+ public let typeMappings : [ String : String ]
285
+ public let node : SyntaxProtocol
286
+
287
+ func buildBoundsChecks( _ variant: Variant ) throws -> [ CodeBlockItemSyntax . Item ] {
288
+ return [ ]
289
+ }
290
+
291
+ func buildFunctionSignature( _ argTypes: [ Int : TypeSyntax ? ] , _ variant: Variant ) throws
292
+ -> FunctionSignatureSyntax {
293
+ var types = argTypes
294
+ let param = getParam ( signature, index)
295
+ let typeName = try getTypeName ( param. type) . text;
296
+ guard let desugaredType = typeMappings [ typeName] else {
297
+ throw DiagnosticError (
298
+ " unable to desugar type with name ' \( typeName) ' " , node: node)
299
+ }
300
+
301
+ let parsedDesugaredType = try TypeSyntax ( " \( raw: desugaredType) " )
302
+ types [ index] = TypeSyntax ( IdentifierTypeSyntax ( name: " Span " ,
303
+ genericArgumentClause: parsedDesugaredType. as ( IdentifierTypeSyntax . self) !. genericArgumentClause) )
304
+ return try base. buildFunctionSignature ( types, variant)
305
+ }
306
+
307
+ func buildFunctionCall( _ pointerArgs: [ Int : ExprSyntax ] , _ variant: Variant ) throws -> ExprSyntax {
308
+ var args = pointerArgs
309
+ let param = getParam ( signature, index)
310
+ let typeName = try getTypeName ( param. type) . text;
311
+ assert ( args [ index] == nil )
312
+ args [ index] = ExprSyntax ( " \( raw: typeName) ( \( raw: param. secondName ?? param. firstName) ) " )
313
+ return try base. buildFunctionCall ( args, variant)
314
+ }
315
+ }
316
+
259
317
protocol PointerBoundsThunkBuilder : BoundsCheckedThunkBuilder {
260
318
var name : TokenSyntax { get }
261
319
var nullable : Bool { get }
@@ -460,7 +518,8 @@ func getParameterIndexForDeclRef(
460
518
/// Depends on bounds, escapability and lifetime information for each pointer.
461
519
/// Intended to map to C attributes like __counted_by, __ended_by and __no_escape,
462
520
/// for automatic application by ClangImporter when the C declaration is annotated
463
- /// appropriately.
521
+ /// appropriately. Moreover, it can wrap C++ APIs using unsafe C++ types like
522
+ /// std::span with APIs that use their safer Swift equivalents.
464
523
public struct SwiftifyImportMacro : PeerMacro {
465
524
static func parseEnumName( _ enumConstructorExpr: FunctionCallExprSyntax ) throws -> String {
466
525
guard let calledExpr = enumConstructorExpr. calledExpression. as ( MemberAccessExprSyntax . self)
@@ -557,6 +616,54 @@ public struct SwiftifyImportMacro: PeerMacro {
557
616
return pointerParamIndex
558
617
}
559
618
619
+ static func parseTypeMappingParam( _ paramAST: LabeledExprSyntax ? ) throws -> [ String : String ] ? {
620
+ guard let unwrappedParamAST = paramAST else {
621
+ return nil
622
+ }
623
+ let paramExpr = unwrappedParamAST. expression
624
+ guard let dictExpr = paramExpr. as ( DictionaryExprSyntax . self) else {
625
+ return nil
626
+ }
627
+ var dict : [ String : String ] = [ : ]
628
+ switch dictExpr. content {
629
+ case . colon( _) :
630
+ return dict
631
+ case . elements( let types) :
632
+ for element in types {
633
+ guard let key = element. key. as ( StringLiteralExprSyntax . self) else {
634
+ throw DiagnosticError ( " expected a string literal, got ' \( element. key) ' " , node: element. key)
635
+ }
636
+ guard let value = element. value. as ( StringLiteralExprSyntax . self) else {
637
+ throw DiagnosticError ( " expected a string literal, got ' \( element. value) ' " , node: element. value)
638
+ }
639
+ dict [ key. representedLiteralValue!] = value. representedLiteralValue!
640
+ }
641
+ default :
642
+ throw DiagnosticError ( " unknown dictionary literal " , node: dictExpr)
643
+ }
644
+ return dict
645
+ }
646
+
647
+ static func parseCxxSpanParams(
648
+ _ signature: FunctionSignatureSyntax ,
649
+ _ typeMappings: [ String : String ] ?
650
+ ) throws -> [ ParamInfo ] {
651
+ guard let typeMappings else {
652
+ return [ ]
653
+ }
654
+ var result : [ ParamInfo ] = [ ]
655
+ for (idx, param) in signature. parameterClause. parameters. enumerated ( ) {
656
+ let typeName = try getTypeName ( param. type) . text;
657
+ if let desugaredType = typeMappings [ typeName] {
658
+ if desugaredType. starts ( with: " span " ) {
659
+ result. append ( CxxSpan ( pointerIndex: idx + 1 , nonescaping: false ,
660
+ original: param, typeMappings: typeMappings) )
661
+ }
662
+ }
663
+ }
664
+ return result
665
+ }
666
+
560
667
static func parseMacroParam(
561
668
_ paramAST: LabeledExprSyntax , _ signature: FunctionSignatureSyntax ,
562
669
nonescapingPointers: inout Set < Int >
@@ -651,11 +758,20 @@ public struct SwiftifyImportMacro: PeerMacro {
651
758
}
652
759
653
760
let argumentList = node. arguments!. as ( LabeledExprListSyntax . self) !
761
+ var arguments = Array < LabeledExprSyntax > ( argumentList)
762
+ let typeMappings = try parseTypeMappingParam ( arguments. last)
763
+ if typeMappings != nil {
764
+ arguments = arguments. dropLast ( )
765
+ }
654
766
var nonescapingPointers = Set < Int > ( )
655
- var parsedArgs = try argumentList . compactMap {
767
+ var parsedArgs = try arguments . compactMap {
656
768
try parseMacroParam ( $0, funcDecl. signature, nonescapingPointers: & nonescapingPointers)
657
769
}
770
+ parsedArgs. append ( contentsOf: try parseCxxSpanParams ( funcDecl. signature, typeMappings) )
658
771
setNonescapingPointers ( & parsedArgs, nonescapingPointers)
772
+ parsedArgs = parsedArgs. filter {
773
+ !( $0 is CxxSpan ) || ( $0 as! CxxSpan ) . nonescaping
774
+ }
659
775
try checkArgs ( parsedArgs, funcDecl)
660
776
let baseBuilder = FunctionCallBuilder ( funcDecl)
661
777
0 commit comments