From a5543817dcc5af2b1f56ca25a7b8c7f0302f2c2f Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 6 Jun 2025 15:41:57 -0400 Subject: [PATCH] Fix for empty SavedModelBundle tags. --- .../java/org/tensorflow/SavedModelBundle.java | 23 ++++++++++++++++--- .../org/tensorflow/SavedModelBundleTest.java | 5 ++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index 0228519e42c..dbbb1ab759d 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -312,6 +312,25 @@ public void export() throws IOException { private final Map functions = new LinkedHashMap<>(); } + /** + * Load a saved model from an export directory. The model that is being loaded should be created + * using the Saved Model + * API. + * + *

This method is a shorthand for: + * + *

{@code
+   * SavedModelBundle.loader().load();
+   * }
+ * + * @param exportDir the directory path containing a saved model. + * @return a bundle containing the graph and associated session. + */ + public static SavedModelBundle load(String exportDir) { + Loader loader = loader(exportDir); + return loader.load(); + } + /** * Load a saved model from an export directory. The model that is being loaded should be created * using the Saved Model @@ -329,9 +348,7 @@ public void export() throws IOException { */ public static SavedModelBundle load(String exportDir, String... tags) { Loader loader = loader(exportDir); - if (tags != null && tags.length > 0) { - loader.withTags(tags); - } + loader.withTags(tags); return loader.load(); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index 4b452984574..5eb3bf71660 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -68,6 +68,11 @@ public class SavedModelBundleTest { @Test public void load() { + try (SavedModelBundle bundle = SavedModelBundle.load(SAVED_MODEL_PATH)) { + assertNotNull(bundle.session()); + assertNotNull(bundle.graph()); + assertNotNull(bundle.metaGraphDef()); + } try (SavedModelBundle bundle = SavedModelBundle.load(SAVED_MODEL_PATH, "serve")) { assertNotNull(bundle.session()); assertNotNull(bundle.graph());