Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 112 additions & 84 deletions core/src/main/java/com/google/adk/sessions/VertexAiClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@
import com.google.common.base.Splitter;
import com.google.common.collect.Iterables;
import com.google.genai.types.HttpOptions;
import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.core.Maybe;
import io.reactivex.rxjava3.core.Single;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeoutException;
import javax.annotation.Nullable;
import okhttp3.ResponseBody;
import org.slf4j.Logger;
Expand Down Expand Up @@ -46,104 +50,124 @@ final class VertexAiClient {
new HttpApiClient(Optional.of(project), Optional.of(location), credentials, httpOptions);
}

@Nullable
JsonNode createSession(
Maybe<JsonNode> createSession(
String reasoningEngineId, String userId, ConcurrentMap<String, Object> state) {
ConcurrentHashMap<String, Object> sessionJsonMap = new ConcurrentHashMap<>();
sessionJsonMap.put("userId", userId);
if (state != null) {
sessionJsonMap.put("sessionState", state);
}

String sessId;
String operationId;
try {
String sessionJson = objectMapper.writeValueAsString(sessionJsonMap);
try (ApiResponse apiResponse =
apiClient.request(
"POST", "reasoningEngines/" + reasoningEngineId + "/sessions", sessionJson)) {
logger.debug("Create Session response {}", apiResponse.getResponseBody());
if (apiResponse == null || apiResponse.getResponseBody() == null) {
return null;
}

JsonNode jsonResponse = getJsonResponse(apiResponse);
if (jsonResponse == null) {
return null;
}
String sessionName = jsonResponse.get("name").asText();
List<String> parts = Splitter.on('/').splitToList(sessionName);
sessId = parts.get(parts.size() - 3);
operationId = Iterables.getLast(parts);
}
} catch (IOException e) {
throw new UncheckedIOException(e);
}
return Single.fromCallable(() -> objectMapper.writeValueAsString(sessionJsonMap))
.flatMap(
sessionJson ->
performApiRequest(
"POST", "reasoningEngines/" + reasoningEngineId + "/sessions", sessionJson))
.flatMapMaybe(
apiResponse -> {
logger.debug("Create Session response {}", apiResponse.getResponseBody());
return getJsonResponse(apiResponse);
})
.flatMap(
jsonResponse -> {
String sessionName = jsonResponse.get("name").asText();
List<String> parts = Splitter.on('/').splitToList(sessionName);
String sessId = parts.get(parts.size() - 3);
String operationId = Iterables.getLast(parts);

return pollOperation(operationId, 0).andThen(getSession(reasoningEngineId, sessId));
});
}

for (int i = 0; i < MAX_RETRY_ATTEMPTS; i++) {
try (ApiResponse lroResponse = apiClient.request("GET", "operations/" + operationId, "")) {
JsonNode lroJsonResponse = getJsonResponse(lroResponse);
if (lroJsonResponse != null && lroJsonResponse.get("done") != null) {
break;
}
}
try {
SECONDS.sleep(1);
} catch (InterruptedException e) {
logger.warn("Error during sleep", e);
Thread.currentThread().interrupt();
}
/**
* Polls the status of a long-running operation.
*
* @param operationId The ID of the operation to poll.
* @param attempt The current retry attempt number (starting from 0).
* @return A Completable that completes when the operation is done, or errors with
* TimeoutException if max retries are exceeded.
*/
private Completable pollOperation(String operationId, int attempt) {
if (attempt >= MAX_RETRY_ATTEMPTS) {
return Completable.error(
new TimeoutException("Operation " + operationId + " did not complete in time."));
}
return getSession(reasoningEngineId, sessId);
return performApiRequest("GET", "operations/" + operationId, "")
.flatMapMaybe(VertexAiClient::getJsonResponse)
.flatMapCompletable(
lroJsonResponse -> {
if (lroJsonResponse != null && lroJsonResponse.get("done") != null) {
return Completable.complete(); // Operation is done
} else {
// Not done, retry after a delay
return Completable.timer(1, SECONDS)
.andThen(pollOperation(operationId, attempt + 1));
}
});
}

JsonNode listSessions(String reasoningEngineId, String userId) {
try (ApiResponse apiResponse =
apiClient.request(
Maybe<JsonNode> listSessions(String reasoningEngineId, String userId) {
return performApiRequest(
"GET",
"reasoningEngines/" + reasoningEngineId + "/sessions?filter=user_id=" + userId,
"")) {
return getJsonResponse(apiResponse);
}
"")
.flatMapMaybe(VertexAiClient::getJsonResponse);
}

JsonNode listEvents(String reasoningEngineId, String sessionId) {
try (ApiResponse apiResponse =
apiClient.request(
Maybe<JsonNode> listEvents(String reasoningEngineId, String sessionId) {
return performApiRequest(
"GET",
"reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId + "/events",
"")) {
logger.debug("List events response {}", apiResponse);
return getJsonResponse(apiResponse);
}
"")
.doOnSuccess(apiResponse -> logger.debug("List events response {}", apiResponse))
.flatMapMaybe(VertexAiClient::getJsonResponse);
}

JsonNode getSession(String reasoningEngineId, String sessionId) {
try (ApiResponse apiResponse =
apiClient.request(
"GET", "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId, "")) {
return getJsonResponse(apiResponse);
}
Maybe<JsonNode> getSession(String reasoningEngineId, String sessionId) {
return performApiRequest(
"GET", "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId, "")
.flatMapMaybe(apiResponse -> getJsonResponse(apiResponse));
}

void deleteSession(String reasoningEngineId, String sessionId) {
try (ApiResponse response =
apiClient.request(
"DELETE", "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId, "")) {}
Completable deleteSession(String reasoningEngineId, String sessionId) {
return performApiRequest(
"DELETE", "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId, "")
.doOnSuccess(ApiResponse::close)
.ignoreElement();
}

void appendEvent(String reasoningEngineId, String sessionId, String eventJson) {
try (ApiResponse response =
apiClient.request(
Completable appendEvent(String reasoningEngineId, String sessionId, String eventJson) {
return performApiRequest(
"POST",
"reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId + ":appendEvent",
eventJson)) {
if (response.getResponseBody().string().contains("com.google.genai.errors.ClientException")) {
logger.warn("Failed to append event: {}", eventJson);
}
} catch (IOException e) {
throw new UncheckedIOException(e);
}
eventJson)
.flatMapCompletable(
response -> {
try (response) {
ResponseBody responseBody = response.getResponseBody();
if (responseBody != null) {
String responseString = responseBody.string();
if (responseString.contains("com.google.genai.errors.ClientException")) {
logger.warn("Failed to append event: {}", eventJson);
}
}
return Completable.complete();
} catch (IOException e) {
return Completable.error(new UncheckedIOException(e));
}
});
}

/**
* Performs an API request and returns a Single emitting the ApiResponse.
*
* <p>Note: The caller is responsible for closing the returned {@link ApiResponse}.
*/
private Single<ApiResponse> performApiRequest(String method, String path, String body) {
return Single.fromCallable(
() -> {
return apiClient.request(method, path, body);
});
}

/**
Expand All @@ -152,19 +176,23 @@ void appendEvent(String reasoningEngineId, String sessionId, String eventJson) {
* @throws UncheckedIOException if parsing fails.
*/
@Nullable
private static JsonNode getJsonResponse(ApiResponse apiResponse) {
if (apiResponse == null || apiResponse.getResponseBody() == null) {
return null;
}
private static Maybe<JsonNode> getJsonResponse(ApiResponse apiResponse) {
try {
ResponseBody responseBody = apiResponse.getResponseBody();
String responseString = responseBody.string();
if (responseString.isEmpty()) {
return null;
if (apiResponse == null || apiResponse.getResponseBody() == null) {
return Maybe.empty();
}
try {
ResponseBody responseBody = apiResponse.getResponseBody();
String responseString = responseBody.string(); // Read body here
if (responseString.isEmpty()) {
return Maybe.empty();
}
return Maybe.just(objectMapper.readTree(responseString));
} catch (IOException e) {
return Maybe.error(new UncheckedIOException(e));
}
return objectMapper.readTree(responseString);
} catch (IOException e) {
throw new UncheckedIOException(e);
} finally {
apiResponse.close();
}
}
}
Loading