diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index 0228519e42c..dbbb1ab759d 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -312,6 +312,25 @@ public void export() throws IOException { private final Map functions = new LinkedHashMap<>(); } + /** + * Load a saved model from an export directory. The model that is being loaded should be created + * using the Saved Model + * API. + * + *

This method is a shorthand for: + * + *

{@code
+   * SavedModelBundle.loader().load();
+   * }
+ * + * @param exportDir the directory path containing a saved model. + * @return a bundle containing the graph and associated session. + */ + public static SavedModelBundle load(String exportDir) { + Loader loader = loader(exportDir); + return loader.load(); + } + /** * Load a saved model from an export directory. The model that is being loaded should be created * using the Saved Model @@ -329,9 +348,7 @@ public void export() throws IOException { */ public static SavedModelBundle load(String exportDir, String... tags) { Loader loader = loader(exportDir); - if (tags != null && tags.length > 0) { - loader.withTags(tags); - } + loader.withTags(tags); return loader.load(); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index 4b452984574..5eb3bf71660 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -68,6 +68,11 @@ public class SavedModelBundleTest { @Test public void load() { + try (SavedModelBundle bundle = SavedModelBundle.load(SAVED_MODEL_PATH)) { + assertNotNull(bundle.session()); + assertNotNull(bundle.graph()); + assertNotNull(bundle.metaGraphDef()); + } try (SavedModelBundle bundle = SavedModelBundle.load(SAVED_MODEL_PATH, "serve")) { assertNotNull(bundle.session()); assertNotNull(bundle.graph());