@@ -54,11 +54,6 @@ actor LiveSessionService {
5454 private let jsonEncoder = JSONEncoder ( )
5555 private let jsonDecoder = JSONDecoder ( )
5656
57- /// Task that doesn't complete until the server sends a setupComplete message.
58- ///
59- /// Used to hold off on sending messages until the server is ready.
60- private var setupTask : Task < Void , Error >
61-
6257 /// Long running task that that wraps around the websocket, propogating messages through the
6358 /// public stream.
6459 private var responsesTask : Task < Void , Never > ?
@@ -87,11 +82,9 @@ actor LiveSessionService {
8782 self . toolConfig = toolConfig
8883 self . systemInstruction = systemInstruction
8984 self . requestOptions = requestOptions
90- setupTask = Task { }
9185 }
9286
9387 deinit {
94- setupTask. cancel ( )
9588 responsesTask? . cancel ( )
9689 messageQueueTask? . cancel ( )
9790 webSocket? . disconnect ( )
@@ -114,29 +107,20 @@ actor LiveSessionService {
114107 ///
115108 /// Seperated into its own function to make it easier to surface a way to call it seperately when
116109 /// resuming the same session.
110+ ///
111+ /// This function will yield until the websocket is ready to communicate with the client.
117112 func connect( ) async throws {
118113 close ( )
119- // we launch the setup task in a seperate task to allow us to cancel it via close
120- setupTask = Task { [ weak self] in
121- // we need a continuation to surface that the setup is complete, while still allowing us to
122- // listen to the server
123- try await withCheckedThrowingContinuation { setupContinuation in
124- // nested task so we can use await
125- Task { [ weak self] in
126- guard let self else { return }
127- await self . listenToServer ( setupContinuation)
128- }
129- }
130- }
131114
132- try await setupTask. value
115+ let stream = try await setupWebsocket ( )
116+ try await waitForSetupComplete ( stream: stream)
117+ spawnMessageTasks ( stream: stream)
133118 }
134119
135120 /// Cancel any running tasks and close the websocket.
136121 ///
137122 /// This method is idempotent; if it's already ran once, it will effectively be a no-op.
138123 func close( ) {
139- setupTask. cancel ( )
140124 responsesTask? . cancel ( )
141125 messageQueueTask? . cancel ( )
142126 webSocket? . disconnect ( )
@@ -146,38 +130,19 @@ actor LiveSessionService {
146130 messageQueueTask = nil
147131 }
148132
149- /// Start a fresh websocket to the backend, and listen for responses .
133+ /// Performs the initial setup procedure for the model .
150134 ///
151- /// Will hold off on sending any messages until the server sends a setupComplete message.
135+ /// The setup procedure with the model follows the procedure:
152136 ///
153- /// Will also close out the old websocket and the previous long running tasks.
154- private func listenToServer( _ setupComplete: CheckedContinuation < Void , any Error > ) async {
155- do {
156- webSocket = try await createWebsocket ( )
157- } catch {
158- let error = LiveSessionSetupError ( underlyingError: error)
159- close ( )
160- setupComplete. resume ( throwing: error)
161- return
162- }
163-
137+ /// - Client sends `BidiGenerateContentSetup`
138+ /// - Server sends back `BidiGenerateContentSetupComplete` when it's ready
139+ ///
140+ /// This function will yield until the setup is complete.
141+ private func waitForSetupComplete( stream: MappedStream <
142+ URLSessionWebSocketTask . Message ,
143+ Data
144+ > ) async throws {
164145 guard let webSocket else { return }
165- let stream = webSocket. connect ( )
166-
167- var resumed = false
168-
169- // remove the uncommon (and unexpected) responses from the stream, to make normal path cleaner
170- let dataStream = stream. compactMap { ( message: URLSessionWebSocketTask . Message ) -> Data ? in
171- switch message {
172- case let . string( string) :
173- AILog . error ( code: . liveSessionUnexpectedResponse, " Unexpected string response: \( string) " )
174- case let . data( data) :
175- return data
176- @unknown default :
177- AILog . error ( code: . liveSessionUnexpectedResponse, " Unknown message received: \( message) " )
178- }
179- return nil
180- }
181146
182147 do {
183148 let setup = BidiGenerateContentSetup (
@@ -194,54 +159,87 @@ actor LiveSessionService {
194159 } catch {
195160 let error = LiveSessionSetupError ( underlyingError: error)
196161 close ( )
197- setupComplete. resume ( throwing: error)
198- return
162+ throw error
199163 }
200164
201- responsesTask = Task {
202- do {
203- for try await message in dataStream {
204- let response : BidiGenerateContentServerMessage
205- do {
206- response = try jsonDecoder. decode (
207- BidiGenerateContentServerMessage . self,
208- from: message
209- )
210- } catch {
211- // only log the json if it wasn't a decoding error, but an unsupported message type
212- if error is InvalidMessageTypeError {
213- AILog . error (
214- code: . liveSessionUnsupportedMessage,
215- " The server sent a message that we don't currently have a mapping for. "
216- )
165+ do {
166+ for try await message in stream {
167+ let response = try decodeServerMessage ( message)
168+ if case . setupComplete = response. messageType {
169+ break
170+ } else {
171+ AILog . error (
172+ code: . liveSessionUnexpectedResponse,
173+ " The model sent us a message that wasn't a setup complete: \( response) "
174+ )
175+ }
176+ }
177+ } catch {
178+ if let error = mapWebsocketError ( error) {
179+ close ( )
180+ throw error
181+ }
182+ // the user called close while setup was running
183+ // this can't currently happen, but could when we add automatic session resumption
184+ // in such case, we don't want to raise an error. this log is more-so to catch any edge cases
185+ AILog . debug (
186+ code: . liveSessionClosedDuringSetup,
187+ " The live session was closed before setup could complete: \( error. localizedDescription) "
188+ )
189+ }
190+ }
217191
218- AILog . debug (
219- code: . liveSessionUnsupportedMessagePayload,
220- message. encodeToJsonString ( ) ?? " \( message) "
221- )
222- }
192+ /// Performs the initial setup procedure for a websocket.
193+ ///
194+ /// This includes creating the websocket url and connecting it.
195+ ///
196+ /// - Returns: A stream of `Data` frames from the websocket.
197+ private func setupWebsocket( ) async throws
198+ -> MappedStream < URLSessionWebSocketTask . Message , Data > {
199+ do {
200+ let webSocket = try await createWebsocket ( )
201+ self . webSocket = webSocket
202+
203+ let stream = webSocket. connect ( )
204+
205+ // remove the uncommon (and unexpected) frames from the stream, to make normal path cleaner
206+ return stream. compactMap { message in
207+ switch message {
208+ case let . string( string) :
209+ AILog . error ( code: . liveSessionUnexpectedResponse, " Unexpected string response: \( string) " )
210+ case let . data( data) :
211+ return data
212+ @unknown default :
213+ AILog . error ( code: . liveSessionUnexpectedResponse, " Unknown message received: \( message) " )
214+ }
215+ return nil
216+ }
217+ } catch {
218+ let error = LiveSessionSetupError ( underlyingError: error)
219+ close ( )
220+ throw error
221+ }
222+ }
223223
224- let error = LiveSessionUnsupportedMessageError ( underlyingError: error)
225- // if we've already finished setting up, then only surface the error through responses
226- // otherwise, make the setup task error as well
227- if !resumed {
228- setupComplete. resume ( throwing: error)
229- }
230- throw error
231- }
224+ /// Spawn tasks for interacting with the model.
225+ ///
226+ /// The following tasks will be spawned:
227+ ///
228+ /// - `responsesTask`: Listen to messages from the server and yield them through `responses`.
229+ /// - `messageQueueTask`: Listen to messages from the client and send them through the websocket.
230+ private func spawnMessageTasks( stream: MappedStream < URLSessionWebSocketTask . Message , Data > ) {
231+ guard let webSocket else { return }
232+
233+ responsesTask = Task {
234+ do {
235+ for try await message in stream {
236+ let response = try decodeServerMessage ( message)
232237
233238 if case . setupComplete = response. messageType {
234- if resumed {
235- AILog . debug (
236- code: . duplicateLiveSessionSetupComplete,
237- " Setup complete was received multiple times; this may be a bug in the model. "
238- )
239- } else {
240- // calling resume multiple times is an error in swift, so we catch multiple calls
241- // to avoid causing any issues due to model quirks
242- resumed = true
243- setupComplete. resume ( )
244- }
239+ AILog . debug (
240+ code: . duplicateLiveSessionSetupComplete,
241+ " Setup complete was received multiple times; this may be a bug in the model. "
242+ )
245243 } else if let liveMessage = LiveServerMessage ( from: response) {
246244 if case let . goingAwayNotice( message) = liveMessage. payload {
247245 // TODO: (b/444045023) When auto session resumption is enabled, call `connect` again
@@ -255,21 +253,7 @@ actor LiveSessionService {
255253 }
256254 }
257255 } catch {
258- if let error = error as? WebSocketClosedError {
259- // only raise an error if the session didn't close normally (ie; the user calling close)
260- if error. closeCode != . goingAway {
261- let closureError : Error
262- if let error = error. underlyingError as? NSError , error. domain == NSURLErrorDomain,
263- error. code == NSURLErrorNetworkConnectionLost {
264- closureError = LiveSessionLostConnectionError ( underlyingError: error)
265- } else {
266- closureError = LiveSessionUnexpectedClosureError ( underlyingError: error)
267- }
268- close ( )
269- responseContinuation. finish ( throwing: closureError)
270- }
271- } else {
272- // an error occurred outside the websocket, so it's likely not closed
256+ if let error = mapWebsocketError ( error) {
273257 close ( )
274258 responseContinuation. finish ( throwing: error)
275259 }
@@ -278,22 +262,7 @@ actor LiveSessionService {
278262
279263 messageQueueTask = Task {
280264 for await message in messageQueue {
281- // we don't propogate errors, since those are surfaced in the responses stream
282- guard let _ = try ? await setupTask. value else {
283- break
284- }
285-
286- let data : Data
287- do {
288- data = try jsonEncoder. encode ( message)
289- } catch {
290- AILog . error ( code: . liveSessionFailedToEncodeClientMessage, error. localizedDescription)
291- AILog . debug (
292- code: . liveSessionFailedToEncodeClientMessagePayload,
293- String ( describing: message)
294- )
295- continue
296- }
265+ guard let data = encodeClientMessage ( message) else { continue }
297266
298267 do {
299268 try await webSocket. send ( . data( data) )
@@ -304,6 +273,75 @@ actor LiveSessionService {
304273 }
305274 }
306275
276+ /// Checks if an error should be propogated up, and maps it accordingly.
277+ ///
278+ /// Some errors have public api alternatives. This function will ensure they're mapped
279+ /// accordingly.
280+ private func mapWebsocketError( _ error: Error ) -> Error ? {
281+ if let error = error as? WebSocketClosedError {
282+ // only raise an error if the session didn't close normally (ie; the user calling close)
283+ if error. closeCode == . goingAway {
284+ return nil
285+ }
286+
287+ let closureError : Error
288+
289+ if let error = error. underlyingError as? NSError , error. domain == NSURLErrorDomain,
290+ error. code == NSURLErrorNetworkConnectionLost {
291+ closureError = LiveSessionLostConnectionError ( underlyingError: error)
292+ } else {
293+ closureError = LiveSessionUnexpectedClosureError ( underlyingError: error)
294+ }
295+
296+ return closureError
297+ }
298+
299+ return error
300+ }
301+
302+ /// Decodes a message from the server's websocket into a valid `BidiGenerateContentServerMessage`.
303+ ///
304+ /// Will throw an error if decoding fails.
305+ private func decodeServerMessage( _ message: Data ) throws -> BidiGenerateContentServerMessage {
306+ do {
307+ return try jsonDecoder. decode (
308+ BidiGenerateContentServerMessage . self,
309+ from: message
310+ )
311+ } catch {
312+ // only log the json if it wasn't a decoding error, but an unsupported message type
313+ if error is InvalidMessageTypeError {
314+ AILog . error (
315+ code: . liveSessionUnsupportedMessage,
316+ " The server sent a message that we don't currently have a mapping for. "
317+ )
318+ AILog . debug (
319+ code: . liveSessionUnsupportedMessagePayload,
320+ message. encodeToJsonString ( ) ?? " \( message) "
321+ )
322+ }
323+
324+ throw LiveSessionUnsupportedMessageError ( underlyingError: error)
325+ }
326+ }
327+
328+ /// Encodes a message from the client into `Data` that can be sent through a websocket data frame.
329+ ///
330+ /// Will return `nil` if decoding fails, and log an error describing why.
331+ private func encodeClientMessage( _ message: BidiGenerateContentClientMessage ) -> Data ? {
332+ do {
333+ return try jsonEncoder. encode ( message)
334+ } catch {
335+ AILog . error ( code: . liveSessionFailedToEncodeClientMessage, error. localizedDescription)
336+ AILog . debug (
337+ code: . liveSessionFailedToEncodeClientMessagePayload,
338+ String ( describing: message)
339+ )
340+ }
341+
342+ return nil
343+ }
344+
307345 /// Creates a websocket pointing to the backend.
308346 ///
309347 /// Will apply the required app check and auth headers, as the backend expects them.
@@ -392,3 +430,8 @@ private extension String {
392430 }
393431 }
394432}
433+
434+ /// Helper alias for a compact mapped throwing stream.
435+ ///
436+ /// We use this to make signatures easier to read, since we can't support `AsyncSequence` quite yet.
437+ private typealias MappedStream < T, V> = AsyncCompactMapSequence < AsyncThrowingStream < T , any Error > , V >
0 commit comments