Skip to content

Commit cf3be03

Browse files
committed
Adapt netty tcp transport implementation to the latest change in transport API and improve the correctness of the implementation
1 parent ba20297 commit cf3be03

File tree

8 files changed

+246
-273
lines changed

8 files changed

+246
-273
lines changed

rsocket-transports/netty-internal/api/rsocket-transport-netty-internal.api

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
public final class io/rsocket/kotlin/transport/netty/internal/CoroutinesKt {
2-
public static final fun awaitChannel (Lio/netty/channel/ChannelFuture;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
32
public static final fun awaitFuture (Lio/netty/util/concurrent/Future;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
4-
public static final fun callOnCancellation (Lkotlinx/coroutines/CoroutineScope;Lkotlin/jvm/functions/Function1;)V
3+
public static final fun join (Lio/netty/util/concurrent/Future;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
4+
public static final fun shutdownOnCancellation (Lkotlinx/coroutines/CoroutineScope;[Lio/netty/channel/EventLoopGroup;)V
55
}
66

77
public final class io/rsocket/kotlin/transport/netty/internal/IoKt {

rsocket-transports/netty-internal/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/internal/coroutines.kt

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2015-2024 the original author or authors.
2+
* Copyright 2015-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@ package io.rsocket.kotlin.transport.netty.internal
1818

1919
import io.netty.channel.*
2020
import io.netty.util.concurrent.*
21+
import io.rsocket.kotlin.internal.io.*
2122
import kotlinx.coroutines.*
2223
import kotlin.coroutines.*
2324

@@ -34,25 +35,24 @@ public suspend inline fun <T> Future<T>.awaitFuture(): T = suspendCancellableCor
3435
}
3536
}
3637

37-
public suspend fun ChannelFuture.awaitChannel(): Channel {
38+
public suspend inline fun Future<*>.join(): Unit = suspendCancellableCoroutine { cont ->
39+
addListener { cont.resume(Unit) }
40+
cont.invokeOnCancellation { cancel(true) }
41+
}
42+
43+
public suspend inline fun <reified T : Channel> ChannelFuture.awaitChannel(): T {
3844
awaitFuture()
39-
return channel()
45+
return channel() as T
4046
}
4147

42-
// it should be used only for cleanup and so should not really block, only suspend
43-
public inline fun CoroutineScope.callOnCancellation(crossinline block: suspend () -> Unit) {
48+
public fun CoroutineScope.shutdownOnCancellation(vararg groups: EventLoopGroup) {
4449
launch(Dispatchers.Unconfined) {
4550
try {
4651
awaitCancellation()
47-
} catch (cause: Throwable) {
48-
withContext(NonCancellable) {
49-
try {
50-
block()
51-
} catch (suppressed: Throwable) {
52-
cause.addSuppressed(suppressed)
53-
}
52+
} finally {
53+
nonCancellable {
54+
groups.forEach { it.shutdownGracefully().join() }
5455
}
55-
throw cause
5656
}
5757
}
5858
}

rsocket-transports/netty-internal/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/internal/io.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2015-2024 the original author or authors.
2+
* Copyright 2015-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -31,12 +31,13 @@ public fun ByteBuf.toBuffer(): Buffer {
3131
toRead
3232
}
3333
}
34+
release()
3435
return buffer
3536
}
3637

3738
@OptIn(UnsafeIoApi::class)
3839
public fun Buffer.toByteBuf(allocator: ByteBufAllocator): ByteBuf {
39-
val nettyBuffer = allocator.buffer(size.toInt()) // TODO: length
40+
val nettyBuffer = allocator.directBuffer(size.toInt()) // TODO: length
4041
while (!exhausted()) {
4142
UnsafeBufferOperations.readFromHead(this) { bytes, start, end ->
4243
nettyBuffer.writeBytes(bytes, start, end - start)

rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransport.kt

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2015-2024 the original author or authors.
2+
* Copyright 2015-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -91,31 +91,25 @@ private class NettyTcpClientTransportBuilderImpl : NettyTcpClientTransportBuilde
9191
}
9292

9393
return NettyTcpClientTransportImpl(
94-
coroutineContext = context.supervisorContext() + bootstrap.config().group().asCoroutineDispatcher(),
95-
sslContext = sslContext,
94+
coroutineContext = context.supervisorContext() + Dispatchers.Default,
9695
bootstrap = bootstrap,
97-
manageBootstrap = manageEventLoopGroup
98-
)
96+
sslContext = sslContext,
97+
).also {
98+
if (manageEventLoopGroup) it.shutdownOnCancellation(bootstrap.config().group())
99+
}
99100
}
100101
}
101102

102103
private class NettyTcpClientTransportImpl(
103104
override val coroutineContext: CoroutineContext,
104-
private val sslContext: SslContext?,
105105
private val bootstrap: Bootstrap,
106-
manageBootstrap: Boolean,
106+
private val sslContext: SslContext?,
107107
) : NettyTcpClientTransport {
108-
init {
109-
if (manageBootstrap) callOnCancellation {
110-
bootstrap.config().group().shutdownGracefully().awaitFuture()
111-
}
112-
}
113-
114-
override fun target(remoteAddress: SocketAddress): NettyTcpClientTargetImpl = NettyTcpClientTargetImpl(
108+
override fun target(remoteAddress: SocketAddress): RSocketClientTarget = NettyTcpClientTargetImpl(
115109
coroutineContext = coroutineContext.supervisorContext(),
116110
bootstrap = bootstrap,
117111
sslContext = sslContext,
118-
remoteAddress = remoteAddress
112+
remoteAddress = remoteAddress,
119113
)
120114

121115
override fun target(host: String, port: Int): RSocketClientTarget = target(InetSocketAddress(host, port))
@@ -124,19 +118,27 @@ private class NettyTcpClientTransportImpl(
124118
@OptIn(RSocketTransportApi::class)
125119
private class NettyTcpClientTargetImpl(
126120
override val coroutineContext: CoroutineContext,
127-
private val bootstrap: Bootstrap,
128-
private val sslContext: SslContext?,
129-
private val remoteAddress: SocketAddress,
121+
bootstrap: Bootstrap,
122+
sslContext: SslContext?,
123+
remoteAddress: SocketAddress,
130124
) : RSocketClientTarget {
131-
@RSocketTransportApi
132-
override fun connectClient(handler: RSocketConnectionHandler): Job = launch {
133-
bootstrap.clone().handler(
125+
private val bootstrap = bootstrap.clone()
126+
.handler(
134127
NettyTcpConnectionInitializer(
128+
parentContext = coroutineContext,
135129
sslContext = sslContext,
136-
remoteAddress = remoteAddress as? InetSocketAddress,
137-
handler = handler,
138-
coroutineContext = coroutineContext
130+
onConnection = null
139131
)
140-
).connect(remoteAddress).awaitFuture()
132+
)
133+
.remoteAddress(remoteAddress)
134+
135+
@RSocketTransportApi
136+
override suspend fun connectClient(): RSocketConnection {
137+
currentCoroutineContext().ensureActive()
138+
coroutineContext.ensureActive()
139+
140+
val channel = bootstrap.connect().awaitChannel<Channel>()
141+
142+
return channel.attr(NettyTcpConnection.ATTRIBUTE).get()
141143
}
142144
}
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
/*
2+
* Copyright 2015-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.rsocket.kotlin.transport.netty.tcp
18+
19+
import io.netty.buffer.*
20+
import io.netty.channel.*
21+
import io.netty.channel.socket.*
22+
import io.netty.handler.codec.*
23+
import io.netty.handler.ssl.*
24+
import io.netty.util.*
25+
import io.rsocket.kotlin.internal.io.*
26+
import io.rsocket.kotlin.transport.*
27+
import io.rsocket.kotlin.transport.internal.*
28+
import io.rsocket.kotlin.transport.netty.internal.*
29+
import kotlinx.coroutines.*
30+
import kotlinx.coroutines.channels.*
31+
import kotlinx.coroutines.channels.Channel
32+
import kotlinx.io.*
33+
import kotlin.coroutines.*
34+
35+
@RSocketTransportApi
36+
internal class NettyTcpConnection(
37+
parentContext: CoroutineContext,
38+
private val channel: DuplexChannel,
39+
) : RSocketSequentialConnection, ChannelInboundHandlerAdapter() {
40+
41+
private val outboundQueue = PrioritizationFrameQueue()
42+
private val inbound = bufferChannel(Channel.UNLIMITED)
43+
44+
override val coroutineContext: CoroutineContext = parentContext.childContext() + channel.eventLoop().asCoroutineDispatcher()
45+
46+
init {
47+
@OptIn(DelicateCoroutinesApi::class)
48+
launch(start = CoroutineStart.ATOMIC) {
49+
val outboundJob = launch(start = CoroutineStart.ATOMIC) {
50+
nonCancellable {
51+
try {
52+
while (true) {
53+
// we write all available frames here, and only after it flush
54+
// in this case, if there are several buffered frames we can send them in one go
55+
// avoiding unnecessary flushes
56+
writeBuffer(outboundQueue.dequeueFrame() ?: break)
57+
while (true) writeBuffer(outboundQueue.tryDequeueFrame() ?: break)
58+
channel.flush()
59+
}
60+
} finally {
61+
outboundQueue.cancel()
62+
channel.shutdownOutput().awaitFuture()
63+
}
64+
}
65+
}
66+
try {
67+
awaitCancellation()
68+
} finally {
69+
nonCancellable {
70+
outboundQueue.close()
71+
inbound.cancel()
72+
channel.shutdownInput().awaitFuture()
73+
outboundJob.join()
74+
channel.close().awaitFuture()
75+
}
76+
}
77+
}
78+
}
79+
80+
override fun channelInactive(ctx: ChannelHandlerContext) {
81+
cancel("Channel is not active")
82+
ctx.fireChannelInactive()
83+
}
84+
85+
override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable?) {
86+
cancel("exceptionCaught", cause)
87+
}
88+
89+
override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any?) {
90+
if (evt === ChannelInputShutdownEvent.INSTANCE) inbound.close()
91+
super.userEventTriggered(ctx, evt)
92+
}
93+
94+
override fun channelRead(ctx: ChannelHandlerContext, msg: Any) {
95+
val buffer = (msg as ByteBuf).toBuffer()
96+
if (inbound.trySend(buffer).isFailure) buffer.clear()
97+
}
98+
99+
override suspend fun sendFrame(streamId: Int, frame: Buffer) {
100+
return outboundQueue.enqueueFrame(streamId, frame)
101+
}
102+
103+
override suspend fun receiveFrame(): Buffer? {
104+
inbound.tryReceive().onSuccess { return it }
105+
channel.read()
106+
return inbound.receiveCatching().getOrNull()
107+
}
108+
109+
private fun writeBuffer(buffer: Buffer) {
110+
channel.write(buffer.toByteBuf(channel.alloc()), channel.voidPromise())
111+
}
112+
113+
companion object {
114+
val ATTRIBUTE: AttributeKey<RSocketConnection> = AttributeKey.newInstance<RSocketConnection>("rsocket-tcp-connection")
115+
}
116+
}
117+
118+
@OptIn(RSocketTransportApi::class)
119+
internal class NettyTcpConnectionInitializer(
120+
private val parentContext: CoroutineContext,
121+
private val sslContext: SslContext?,
122+
private val onConnection: ((RSocketConnection) -> Unit)?,
123+
) : ChannelInitializer<DuplexChannel>() {
124+
override fun initChannel(channel: DuplexChannel) {
125+
channel.config().isAutoRead = false
126+
127+
val connection = NettyTcpConnection(parentContext, channel)
128+
channel.attr(NettyTcpConnection.ATTRIBUTE).set(connection)
129+
130+
if (sslContext != null) {
131+
channel.pipeline().addLast("ssl", sslContext.newHandler(channel.alloc()))
132+
}
133+
channel.pipeline().addLast(
134+
"rsocket-length-encoder",
135+
LengthFieldPrepender(
136+
/* lengthFieldLength = */ 3
137+
)
138+
)
139+
channel.pipeline().addLast(
140+
"rsocket-length-decoder",
141+
LengthFieldBasedFrameDecoder(
142+
/* maxFrameLength = */ Int.MAX_VALUE,
143+
/* lengthFieldOffset = */ 0,
144+
/* lengthFieldLength = */ 3,
145+
/* lengthAdjustment = */ 0,
146+
/* initialBytesToStrip = */ 3
147+
)
148+
)
149+
channel.pipeline().addLast(
150+
"rsocket-connection",
151+
connection
152+
)
153+
154+
onConnection?.invoke(connection)
155+
}
156+
}

0 commit comments

Comments
 (0)