Skip to content

Get Windows to actually compile and pass some tests #148

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion FlyingSocks/Sources/Mutex.swift
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ extension Mutex {
}

func tryLock() -> Bool {
TryAcquireSRWLockExclusive(_lock)
TryAcquireSRWLockExclusive(_lock) != 0
}
}
}
Expand Down
136 changes: 125 additions & 11 deletions FlyingSocks/Sources/Socket+WinSock2.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,17 @@

#if canImport(WinSDK)
import WinSDK.WinSock2
import Foundation

let O_NONBLOCK = Int32(1)
let F_SETFL = Int32(1)
let F_GETFL = Int32(1)
var errno: Int32 { WSAGetLastError() }
let EWOULDBLOCK = WSAEWOULDBLOCK
let EBADF = WSA_INVALID_HANDLE
let EBADF = WSAENOTSOCK
let EINPROGRESS = WSAEINPROGRESS
let EISCONN = WSAEISCONN
public typealias sa_family_t = UInt8
public typealias sa_family_t = ADDRESS_FAMILY

public extension Socket {
typealias FileDescriptorType = UInt64
Expand All @@ -59,10 +60,9 @@ extension Socket {
static let datagram = Int32(SOCK_DGRAM)
static let in_addr_any = WinSDK.in_addr()
static let ipproto_ip = Int32(IPPROTO_IP)
static let ipproto_ipv6 = Int32(IPPROTO_IPV6)
static let ipproto_ipv6 = Int32(IPPROTO_IPV6.rawValue)
static let ip_pktinfo = Int32(IP_PKTINFO)
static let ipv6_pktinfo = Int32(IPV6_PKTINFO)
static let ipv6_recvpktinfo = Int32(IPV6_RECVPKTINFO)

static func makeAddressINET(port: UInt16) -> WinSDK.sockaddr_in {
WinSDK.sockaddr_in(
Expand Down Expand Up @@ -107,6 +107,7 @@ extension Socket {
}

static func fcntl(_ fd: FileDescriptorType, _ cmd: Int32) -> Int32 {
guard fd != INVALID_SOCKET else { return -1 }
return 0
}

Expand All @@ -123,7 +124,110 @@ extension Socket {
}

static func socketpair(_ domain: Int32, _ type: Int32, _ protocol: Int32) -> (FileDescriptorType, FileDescriptorType) {
(-1, -1) // no supported
guard domain == AF_UNIX else { return (INVALID_SOCKET, INVALID_SOCKET) }
func makeTempUnixPath() -> URL {
let tempURL = FileManager.default.temporaryDirectory.appendingPathComponent("socketpair_\(UUID().uuidString.prefix(8)).sock", isDirectory: false)
try? FileManager.default.removeItem(at: tempURL)
return tempURL
}

if type == SOCK_STREAM {
let tempURL = makeTempUnixPath()
defer { try? FileManager.default.removeItem(at: tempURL) }

let listener = socket(domain, type, `protocol`)

guard listener != INVALID_SOCKET else { return (INVALID_SOCKET, INVALID_SOCKET) }

let addr = makeAddressUnix(path: tempURL.path)

let bindListenerResult = addr.withSockAddr {
bind(listener, $0, addr.size)
}

guard bindListenerResult == 0 else { return (INVALID_SOCKET, INVALID_SOCKET) }

guard listen(listener, 1) == 0 else {
_ = close(listener)
return (INVALID_SOCKET, INVALID_SOCKET)
}

let connector = socket(domain, type, `protocol`)

guard connector != INVALID_SOCKET else {
_ = close(listener)
return (INVALID_SOCKET, INVALID_SOCKET)
}

let connectResult = addr.withSockAddr { connect(connector, $0, addr.size) == 0 }

guard connectResult else {
_ = close(listener)
_ = close(connector)
return (INVALID_SOCKET, INVALID_SOCKET)
}

let acceptor = accept(listener, nil, nil)
guard acceptor != INVALID_SOCKET else {
_ = close(listener)
_ = close(connector)
return (INVALID_SOCKET, INVALID_SOCKET)
}

_ = close(listener)

return (connector, acceptor)
} else if type == SOCK_DGRAM {
return (INVALID_SOCKET, INVALID_SOCKET)
// unsupported at this time: https://github.com/microsoft/WSL/issues/5272
// let tempURL1 = makeTempUnixPath()
// let tempURL2 = makeTempUnixPath()
// guard FileManager.default.createFile(atPath: tempURL1.path, contents: nil) else { return (INVALID_SOCKET, INVALID_SOCKET) }
// guard FileManager.default.createFile(atPath: tempURL2.path, contents: nil) else { return (INVALID_SOCKET, INVALID_SOCKET) }

// defer { try? FileManager.default.removeItem(at: tempURL1) }
// defer { try? FileManager.default.removeItem(at: tempURL2) }

// let socket1 = socket(domain, type, `protocol`)
// let socket2 = socket(domain, type, `protocol`)

// guard socket1 != INVALID_SOCKET, socket2 != INVALID_SOCKET else {
// if socket1 != INVALID_SOCKET { _ = close(socket1) }
// if socket2 != INVALID_SOCKET { _ = close(socket2) }
// return (INVALID_SOCKET, INVALID_SOCKET)
// }

// let addr1 = makeAddressUnix(path: tempURL1.path)
// let addr2 = makeAddressUnix(path: tempURL2.path)

// guard addr1.withSockAddr({ bind(socket1, $0, addr1.size) }) == 0 else {
// _ = close(socket1)
// _ = close(socket2)
// return (INVALID_SOCKET, INVALID_SOCKET)
// }

// guard addr2.withSockAddr({ bind(socket2, $0, addr2.size) }) == 0 else {
// _ = close(socket1)
// _ = close(socket2)
// return (INVALID_SOCKET, INVALID_SOCKET)
// }

// guard addr2.withSockAddr({ connect(socket1, $0, addr2.size) }) == 0 else {
// _ = close(socket1)
// _ = close(socket2)
// return (INVALID_SOCKET, INVALID_SOCKET)
// }

// guard addr1.withSockAddr({ connect(socket2, $0, addr1.size) }) == 0 else {
// _ = close(socket1)
// _ = close(socket2)
// return (INVALID_SOCKET, INVALID_SOCKET)
// }

// return (socket1, socket2)
} else {
return (INVALID_SOCKET, INVALID_SOCKET)
}
}

static func setsockopt(_ fd: FileDescriptorType, _ level: Int32, _ name: Int32,
Expand Down Expand Up @@ -196,19 +300,29 @@ extension Socket {
}

static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer<sockaddr>!, _ len: UnsafeMutablePointer<socklen_t>!) -> Int {
WinSDK.recvfrom(fd, buffer, nbyte, flags, addr, len)
Int(WinSDK.recvfrom(fd, buffer, Int32(nbyte), flags, addr, len))
}

static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer<sockaddr>!, _ destlen: socklen_t) -> Int {
WinSDK.sendto(fd, buffer, nbyte, flags, destaddr, destlen)
Int(WinSDK.sendto(fd, buffer, Int32(nbyte), flags, destaddr, destlen))
}
}

static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer<msghdr>, _ flags: Int32) -> Int {
WinSDK.recvmsg(fd, message, flags)
public extension in_addr {
var s_addr: UInt32 {
get {
S_un.S_addr
} set {
S_un.S_addr = newValue
}
}
}

static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer<msghdr>, _ flags: Int32) -> Int {
WinSDK.sendmsg(fd, message, flags)
private extension URL {
var fileSystemRepresentation: String {
withUnsafeFileSystemRepresentation {
String(cString: $0!)
}
}
}

Expand Down
14 changes: 13 additions & 1 deletion FlyingSocks/Sources/Socket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
//

#if canImport(WinSDK)
import WinSDK.WinSock2
@_exported import WinSDK.WinSock2
#elseif canImport(Android)
@_exported import Android
#endif
Expand Down Expand Up @@ -126,8 +126,10 @@ public struct Socket: Sendable, Hashable {
switch domain {
case AF_INET:
try setValue(true, for: .packetInfoIP)
#if !canImport(WinSDK)
case AF_INET6:
try setValue(true, for: .packetInfoIPv6)
#endif
default:
return
}
Expand Down Expand Up @@ -566,9 +568,11 @@ public extension SocketOption where Self == BoolSocketOption {
BoolSocketOption(level: Socket.ipproto_ip, name: Socket.ip_pktinfo)
}

#if !canImport(WinSDK)
static var packetInfoIPv6: Self {
BoolSocketOption(level: Socket.ipproto_ipv6, name: Socket.ipv6_recvpktinfo)
}
#endif

#if canImport(Darwin)
// Prevents SIG_TRAP when app is paused / running in background.
Expand Down Expand Up @@ -597,9 +601,17 @@ package extension Socket {

static func makePair(flags: Flags? = nil, type: SocketType = .stream) throws -> (Socket, Socket) {
let (file1, file2) = Socket.socketpair(AF_UNIX, type.rawValue, 0)

#if canImport(WinSDK)
guard file1 != INVALID_SOCKET, file2 != INVALID_SOCKET else {
throw SocketError.makeFailed("SocketPair")
}
#else
guard file1 > -1, file2 > -1 else {
throw SocketError.makeFailed("SocketPair")
}
#endif

let s1 = Socket(file: .init(rawValue: file1))
let s2 = Socket(file: .init(rawValue: file2))

Expand Down
9 changes: 8 additions & 1 deletion FlyingSocks/Tests/FileManager+TemporaryFile.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@ extension FileManager {
let dirPath = temporaryDirectory.appendingPathComponent("FlyingSocks.XXXXXX")
return dirPath.withUnsafeFileSystemRepresentation { maybePath in
guard let path = maybePath else { return nil }
var mutablePath = Array(repeating: Int8(0), count: Int(PATH_MAX))

#if canImport(WinSDK)
let pathMax = Int(MAX_PATH)
#else
let pathMax = Int(PATH_MAX)
#endif

var mutablePath = Array(repeating: Int8(0), count: pathMax)
mutablePath.withUnsafeMutableBytes { mutablePathBufferPtr in
mutablePathBufferPtr.baseAddress!.copyMemory(
from: path, byteCount: Int(strlen(path)) + 1)
Expand Down
5 changes: 5 additions & 0 deletions FlyingSocks/Tests/SocketErrorTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,12 @@ struct SocketErrorTests {

@Test
func socketError_makeFailed() {
#if canImport(WinSDK)
WSASetLastError(EIO)
#else
errno = EIO
#endif

let socketError = SocketError.makeFailed("unit-test")
switch socketError {
case let .failed(type: type, errno: socketErrno, message: message):
Expand Down
2 changes: 1 addition & 1 deletion FlyingSocks/Tests/SocketPool+PollTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ private extension Poll {
}

private extension pollfd {
static func make(fd: Int32 = 0,
static func make(fd: Socket.FileDescriptorType = 0,
events: Int32 = POLLIN,
revents: Int32 = POLLIN) -> Self {
.init(fd: fd, events: Int16(events), revents: Int16(revents))
Expand Down
Loading