1010import java .util .List ;
1111import java .util .Optional ;
1212import java .util .*;
13+ import java .util .concurrent .CompletableFuture ;
14+ import java .util .concurrent .Future ;
1315import java .util .function .BiConsumer ;
1416import java .util .function .Consumer ;
1517
@@ -40,7 +42,11 @@ public final class RealtimeTranscriber implements AutoCloseable {
4042 private final Consumer <Throwable > onError ;
4143 private final BiConsumer <Integer , String > onClose ;
4244 private final RealtimeMessageVisitor realtimeMessageVisitor ;
45+ private final Consumer <SessionInformation > onSessionInformation ;
4346 private WebSocket webSocket ;
47+ private SessionInformation sessionInformation ;
48+ private CompletableFuture <SessionInformation > sessionTerminatedFuture ;
49+ private boolean isConnected ;
4450
4551 private RealtimeTranscriber (
4652 String apiKey ,
@@ -55,6 +61,7 @@ private RealtimeTranscriber(
5561 Consumer <FinalTranscript > onFinalTranscript ,
5662 Consumer <RealtimeTranscript > onTranscript ,
5763 Consumer <Throwable > onError ,
64+ Consumer <SessionInformation > onSessionInformation ,
5865 BiConsumer <Integer , String > onClose ) {
5966 this .apiKey = apiKey ;
6067 this .token = token ;
@@ -68,6 +75,7 @@ private RealtimeTranscriber(
6875 this .onFinalTranscript = onFinalTranscript ;
6976 this .onTranscript = onTranscript ;
7077 this .onError = onError ;
78+ this .onSessionInformation = onSessionInformation ;
7179 this .onClose = onClose ;
7280 this .realtimeMessageVisitor = new RealtimeMessageVisitor ();
7381 }
@@ -83,6 +91,10 @@ public void connect() {
8391 if (disablePartialTranscripts ) {
8492 url += "&disable_partial_transcripts=true" ;
8593 }
94+
95+ // always set so it can be return from closeWithSessionTermination
96+ url += "&enable_extra_session_information=true" ;
97+
8698 if (wordBoost .isPresent () && !wordBoost .get ().isEmpty ()) {
8799 try {
88100 url += "&word_boost=" + ObjectMappers .JSON_MAPPER .writeValueAsString (wordBoost .get ());
@@ -144,15 +156,33 @@ public void configureEndUtteranceSilenceThreshold(int threshold) {
144156 ));
145157 }
146158
159+ public Future <SessionInformation > closeWithSessionTermination () {
160+ this .sessionTerminatedFuture = new CompletableFuture <SessionInformation >();
161+ this .webSocket .send ("{\" terminate_session\" :true}" );
162+ sessionTerminatedFuture .whenComplete ((sessionInformation1 , throwable ) -> this .closeSocket ());
163+ return this .sessionTerminatedFuture ;
164+ }
165+
147166 /**
148- * Closes the websocket connection.
167+ * Closes the websocket connection immediately, without waiting for session termination.
168+ * Use closeWithSessionTermination() if possible.
169+ *
170+ * @see #closeWithSessionTermination
171+ * Terminate the session, wait for session termination, and then close the connection.
149172 */
150173 @ Override
151174 public void close () {
152- boolean closed = this .webSocket .close (1000 , "Shutting down" );
153- if (!closed ) {
154- this .webSocket .cancel ();
175+ if (isConnected ) {
176+ this .webSocket .send ("{\" terminate_session\" :true}" );
155177 }
178+ this .closeSocket ();
179+ }
180+
181+ private void closeSocket () {
182+ if (webSocket == null ) return ;
183+ this .webSocket .close (1000 , "Shutting down" );
184+ this .webSocket .cancel ();
185+ this .webSocket = null ;
156186 }
157187
158188 public static RealtimeTranscriber .Builder builder () {
@@ -174,6 +204,7 @@ public static final class Builder {
174204 private Consumer <RealtimeTranscript > onTranscript ;
175205 private Consumer <Throwable > onError ;
176206 private BiConsumer <Integer , String > onClose ;
207+ private Consumer <SessionInformation > onSessionInformation ;
177208
178209 /**
179210 * Sets the AssemblyAI API key used to authenticate the RealtimeTranscriber
@@ -323,6 +354,19 @@ public RealtimeTranscriber.Builder onError(Consumer<Throwable> onError) {
323354 return this ;
324355 }
325356
357+ /**
358+ * Sets onSessionInformation
359+ *
360+ * @param onSessionInformation an event handler for the session information event.
361+ * This message is sent at the end of the session, before the SessionTerminated message.
362+ * Defaults to a noop.
363+ * @return this
364+ */
365+ public RealtimeTranscriber .Builder onSessionInformation (Consumer <SessionInformation > onSessionInformation ) {
366+ this .onSessionInformation = onSessionInformation ;
367+ return this ;
368+ }
369+
326370 /**
327371 * Sets onClose
328372 *
@@ -351,6 +395,7 @@ public RealtimeTranscriber build() {
351395 onFinalTranscript ,
352396 onTranscript ,
353397 onError ,
398+ onSessionInformation ,
354399 onClose );
355400 }
356401 }
@@ -364,6 +409,7 @@ public Listener(Consumer<Response> onOpen) {
364409
365410 @ Override
366411 public void onOpen (@ NotNull WebSocket webSocket , @ NotNull Response response ) {
412+ isConnected = true ;
367413 if (onOpen != null ) {
368414 onOpen .accept (response );
369415 }
@@ -372,12 +418,29 @@ public void onOpen(@NotNull WebSocket webSocket, @NotNull Response response) {
372418 @ Override
373419 public void onMessage (@ NotNull WebSocket webSocket , @ NotNull String text ) {
374420 try {
375- RealtimeMessage realtimeMessage = ObjectMappers .JSON_MAPPER .readValue (text , RealtimeMessage .class );
376- try {
377- realtimeMessage .visit (realtimeMessageVisitor );
378- } catch (IllegalStateException ignored ) {
379- // when a new message is added to the API, this should not throw an exception
421+ RealtimeBaseMessage baseMessage = ObjectMappers .parseOrThrow (text , RealtimeBaseMessage .class );
422+ MessageType messageType = baseMessage .getMessageType ();
423+ if (messageType == MessageType .SESSION_BEGINS ) {
424+ realtimeMessageVisitor .visit (
425+ ObjectMappers .JSON_MAPPER .readValue (text , SessionBegins .class )
426+ );
427+ } else if (messageType == MessageType .PARTIAL_TRANSCRIPT ) {
428+ realtimeMessageVisitor .visit (
429+ ObjectMappers .JSON_MAPPER .readValue (text , PartialTranscript .class )
430+ );
431+ } else if (messageType == MessageType .FINAL_TRANSCRIPT ) {
432+ realtimeMessageVisitor .visit (
433+ ObjectMappers .JSON_MAPPER .readValue (text , FinalTranscript .class )
434+ );
435+ } else if (messageType == MessageType .SESSION_INFORMATION ) {
436+ realtimeMessageVisitor .visit (
437+ ObjectMappers .JSON_MAPPER .readValue (text , SessionInformation .class )
438+ );
439+ } else if (messageType == MessageType .SESSION_TERMINATED ) {
440+ realtimeMessageVisitor .visit ((SessionTerminated ) null );
380441 }
442+ // Intentionally don't throw an exception for unknown message type.
443+ // New message types shouldn't cause this to break.
381444 } catch (JsonProcessingException e ) {
382445 if (onError == null ) return ;
383446 onError .accept (e );
@@ -386,6 +449,7 @@ public void onMessage(@NotNull WebSocket webSocket, @NotNull String text) {
386449
387450 @ Override
388451 public void onFailure (@ NotNull WebSocket webSocket , @ NotNull Throwable t , @ Nullable Response response ) {
452+ isConnected = false ;
389453 if (onError == null ) return ;
390454 onError .accept (t );
391455 }
@@ -399,6 +463,12 @@ public void onClosing(@NotNull WebSocket webSocket, int code, String reason) {
399463 onClose .accept (code , reason );
400464 super .onClosing (webSocket , code , reason );
401465 }
466+
467+ @ Override
468+ public void onClosed (@ NotNull WebSocket webSocket , int code , @ NotNull String reason ) {
469+ isConnected = false ;
470+ super .onClosed (webSocket , code , reason );
471+ }
402472 }
403473
404474 private final class RealtimeMessageVisitor implements RealtimeMessage .Visitor <Void > {
@@ -423,8 +493,20 @@ public Void visit(FinalTranscript value) {
423493 return null ;
424494 }
425495
496+ @ Override
497+ public Void visit (SessionInformation value ) {
498+ sessionInformation = value ;
499+ if (onSessionInformation == null ) return null ;
500+ onSessionInformation .accept (value );
501+ return null ;
502+ }
503+
504+
426505 @ Override
427506 public Void visit (SessionTerminated value ) {
507+ if (sessionTerminatedFuture != null ) {
508+ sessionTerminatedFuture .complete (sessionInformation );
509+ }
428510 return null ;
429511 }
430512
0 commit comments