Skip to content

Commit b957e64

Browse files
authored
Fix for empty SavedModelBundle tags. (#611)
1 parent 3f94211 commit b957e64

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,25 @@ public void export() throws IOException {
312312
private final Map<String, SessionFunction> functions = new LinkedHashMap<>();
313313
}
314314

315+
/**
316+
* Load a saved model from an export directory. The model that is being loaded should be created
317+
* using the <a href="https://www.tensorflow.org/api_docs/python/tf/saved_model">Saved Model
318+
* API</a>.
319+
*
320+
* <p>This method is a shorthand for:
321+
*
322+
* <pre>{@code
323+
* SavedModelBundle.loader().load();
324+
* }</pre>
325+
*
326+
* @param exportDir the directory path containing a saved model.
327+
* @return a bundle containing the graph and associated session.
328+
*/
329+
public static SavedModelBundle load(String exportDir) {
330+
Loader loader = loader(exportDir);
331+
return loader.load();
332+
}
333+
315334
/**
316335
* Load a saved model from an export directory. The model that is being loaded should be created
317336
* using the <a href="https://www.tensorflow.org/api_docs/python/tf/saved_model">Saved Model
@@ -329,9 +348,7 @@ public void export() throws IOException {
329348
*/
330349
public static SavedModelBundle load(String exportDir, String... tags) {
331350
Loader loader = loader(exportDir);
332-
if (tags != null && tags.length > 0) {
333-
loader.withTags(tags);
334-
}
351+
loader.withTags(tags);
335352
return loader.load();
336353
}
337354

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ public class SavedModelBundleTest {
6868

6969
@Test
7070
public void load() {
71+
try (SavedModelBundle bundle = SavedModelBundle.load(SAVED_MODEL_PATH)) {
72+
assertNotNull(bundle.session());
73+
assertNotNull(bundle.graph());
74+
assertNotNull(bundle.metaGraphDef());
75+
}
7176
try (SavedModelBundle bundle = SavedModelBundle.load(SAVED_MODEL_PATH, "serve")) {
7277
assertNotNull(bundle.session());
7378
assertNotNull(bundle.graph());

0 commit comments

Comments
 (0)