Skip to content

Commit d8e2269

Browse files
authored
change(ml): Drop AutoML model support (#894)
1 parent dae267c commit d8e2269

File tree

3 files changed

+3
-178
lines changed

3 files changed

+3
-178
lines changed

firebase_admin/ml.py

Lines changed: 2 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import time
2525
import os
2626
from urllib import parse
27-
import warnings
2827

2928
import requests
3029

@@ -33,14 +32,14 @@
3332
from firebase_admin import _utils
3433
from firebase_admin import exceptions
3534

36-
# pylint: disable=import-error,no-name-in-module
35+
# pylint: disable=import-error,no-member
3736
try:
3837
from firebase_admin import storage
3938
_GCS_ENABLED = True
4039
except ImportError:
4140
_GCS_ENABLED = False
4241

43-
# pylint: disable=import-error,no-name-in-module
42+
# pylint: disable=import-error,no-member
4443
try:
4544
import tensorflow as tf
4645
_TF_ENABLED = True
@@ -54,9 +53,6 @@
5453
_TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,32}$')
5554
_GCS_TFLITE_URI_PATTERN = re.compile(
5655
r'^gs://(?P<bucket_name>[a-z0-9_.-]{3,63})/(?P<blob_name>.+)$')
57-
_AUTO_ML_MODEL_PATTERN = re.compile(
58-
r'^projects/(?P<project_id>[a-z0-9-]{6,30})/locations/(?P<location_id>[^/]+)/' +
59-
r'models/(?P<model_id>[A-Za-z0-9]+)$')
6056
_RESOURCE_NAME_PATTERN = re.compile(
6157
r'^projects/(?P<project_id>[a-z0-9-]{6,30})/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$')
6258
_OPERATION_NAME_PATTERN = re.compile(
@@ -388,11 +384,6 @@ def _init_model_source(data):
388384
gcs_tflite_uri = data.pop('gcsTfliteUri', None)
389385
if gcs_tflite_uri:
390386
return TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri)
391-
auto_ml_model = data.pop('automlModel', None)
392-
if auto_ml_model:
393-
warnings.warn('AutoML model support is deprecated and will be removed in the next '
394-
'major version.', DeprecationWarning)
395-
return TFLiteAutoMlSource(auto_ml_model=auto_ml_model)
396387
return None
397388

398389
@property
@@ -606,42 +597,6 @@ def as_dict(self, for_upload=False):
606597

607598
return {'gcsTfliteUri': self._gcs_tflite_uri}
608599

609-
610-
class TFLiteAutoMlSource(TFLiteModelSource):
611-
"""TFLite model source representing a tflite model created with AutoML.
612-
613-
AutoML model support is deprecated and will be removed in the next major version.
614-
"""
615-
616-
def __init__(self, auto_ml_model, app=None):
617-
warnings.warn('AutoML model support is deprecated and will be removed in the next '
618-
'major version.', DeprecationWarning)
619-
self._app = app
620-
self.auto_ml_model = auto_ml_model
621-
622-
def __eq__(self, other):
623-
if isinstance(other, self.__class__):
624-
return self.auto_ml_model == other.auto_ml_model
625-
return False
626-
627-
def __ne__(self, other):
628-
return not self.__eq__(other)
629-
630-
@property
631-
def auto_ml_model(self):
632-
"""Resource name of the model, created by the AutoML API or Cloud console."""
633-
return self._auto_ml_model
634-
635-
@auto_ml_model.setter
636-
def auto_ml_model(self, auto_ml_model):
637-
self._auto_ml_model = _validate_auto_ml_model(auto_ml_model)
638-
639-
def as_dict(self, for_upload=False):
640-
"""Returns a serializable representation of the object."""
641-
# Upload is irrelevant for auto_ml models
642-
return {'automlModel': self._auto_ml_model}
643-
644-
645600
class ListModelsPage:
646601
"""Represents a page of models in a Firebase project.
647602
@@ -786,11 +741,6 @@ def _validate_gcs_tflite_uri(uri):
786741
raise ValueError('GCS TFLite URI format is invalid.')
787742
return uri
788743

789-
def _validate_auto_ml_model(model):
790-
if not _AUTO_ML_MODEL_PATTERN.match(model):
791-
raise ValueError('Model resource name format is invalid.')
792-
return model
793-
794744

795745
def _validate_model_format(model_format):
796746
if not isinstance(model_format, ModelFormat):

integration/test_ml.py

Lines changed: 1 addition & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -22,25 +22,18 @@
2222

2323
import pytest
2424

25-
import firebase_admin
2625
from firebase_admin import exceptions
2726
from firebase_admin import ml
2827
from tests import testutils
2928

3029

31-
# pylint: disable=import-error,no-name-in-module
30+
# pylint: disable=import-error, no-member
3231
try:
3332
import tensorflow as tf
3433
_TF_ENABLED = True
3534
except ImportError:
3635
_TF_ENABLED = False
3736

38-
try:
39-
from google.cloud import automl_v1
40-
_AUTOML_ENABLED = True
41-
except ImportError:
42-
_AUTOML_ENABLED = False
43-
4437
def _random_identifier(prefix):
4538
#pylint: disable=unused-variable
4639
suffix = ''.join([random.choice(string.ascii_letters + string.digits) for n in range(8)])
@@ -159,14 +152,6 @@ def check_tflite_gcs_format(model, validation_error=None):
159152
assert model.model_hash is not None
160153

161154

162-
def check_tflite_automl_format(model):
163-
assert model.validation_error is None
164-
assert model.published is False
165-
assert model.model_format.model_source.auto_ml_model.startswith('projects/')
166-
# Automl models don't have validation errors since they are references
167-
# to valid automl models.
168-
169-
170155
@pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True)
171156
def test_create_simple_model(firebase_model):
172157
check_model(firebase_model, NAME_AND_TAGS_ARGS)
@@ -392,50 +377,3 @@ def test_from_saved_model(saved_model_dir):
392377
assert created_model.validation_error is None
393378
finally:
394379
_clean_up_model(created_model)
395-
396-
397-
# Test AutoML functionality if AutoML is enabled.
398-
#'pip install google-cloud-automl' in the environment if you want _AUTOML_ENABLED = True
399-
# You will also need a predefined AutoML model named 'admin_sdk_integ_test1' to run the
400-
# successful test. (Test is skipped otherwise)
401-
402-
@pytest.fixture
403-
def automl_model():
404-
assert _AUTOML_ENABLED
405-
406-
# It takes > 20 minutes to train a model, so we expect a predefined AutoMl
407-
# model named 'admin_sdk_integ_test1' to exist in the project, or we skip
408-
# the test.
409-
automl_client = automl_v1.AutoMlClient()
410-
project_id = firebase_admin.get_app().project_id
411-
parent = automl_client.location_path(project_id, 'us-central1')
412-
models = automl_client.list_models(parent, filter_="display_name=admin_sdk_integ_test1")
413-
# Expecting exactly one. (Ok to use last one if somehow more than 1)
414-
automl_ref = None
415-
for model in models:
416-
automl_ref = model.name
417-
418-
# Skip if no pre-defined model. (It takes min > 20 minutes to train a model)
419-
if automl_ref is None:
420-
pytest.skip("No pre-existing AutoML model found. Skipping test")
421-
422-
source = ml.TFLiteAutoMlSource(automl_ref)
423-
tflite_format = ml.TFLiteFormat(model_source=source)
424-
ml_model = ml.Model(
425-
display_name=_random_identifier('TestModel_automl_'),
426-
tags=['test_automl'],
427-
model_format=tflite_format)
428-
model = ml.create_model(model=ml_model)
429-
yield model
430-
_clean_up_model(model)
431-
432-
@pytest.mark.skipif(not _AUTOML_ENABLED, reason='AutoML is required for this test.')
433-
def test_automl_model(automl_model):
434-
# This test looks for a predefined automl model with display_name = 'admin_sdk_integ_test1'
435-
automl_model.wait_for_unlocked()
436-
437-
check_model(automl_model, {
438-
'display_name': automl_model.display_name,
439-
'tags': ['test_automl'],
440-
})
441-
check_tflite_automl_format(automl_model)

tests/test_ml.py

Lines changed: 0 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -121,18 +121,6 @@
121121
}
122122
TFLITE_FORMAT_2 = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_2)
123123

124-
AUTOML_MODEL_NAME = 'projects/111111111111/locations/us-central1/models/ICN7683346839371803263'
125-
AUTOML_MODEL_SOURCE = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME)
126-
TFLITE_FORMAT_JSON_3 = {
127-
'automlModel': AUTOML_MODEL_NAME,
128-
'sizeBytes': '3456789'
129-
}
130-
TFLITE_FORMAT_3 = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_3)
131-
132-
AUTOML_MODEL_NAME_2 = 'projects/2222222222/locations/us-central1/models/ICN2222222222222222222'
133-
AUTOML_MODEL_NAME_JSON_2 = {'automlModel': AUTOML_MODEL_NAME_2}
134-
AUTOML_MODEL_SOURCE_2 = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME_2)
135-
136124
CREATED_UPDATED_MODEL_JSON_1 = {
137125
'name': MODEL_NAME_1,
138126
'displayName': DISPLAY_NAME_1,
@@ -423,14 +411,6 @@ def test_model_keyword_based_creation_and_setters(self):
423411
'tfliteModel': TFLITE_FORMAT_JSON_2
424412
}
425413

426-
model.model_format = TFLITE_FORMAT_3
427-
assert model.as_dict() == {
428-
'displayName': DISPLAY_NAME_2,
429-
'tags': TAGS_2,
430-
'tfliteModel': TFLITE_FORMAT_JSON_3
431-
}
432-
433-
434414
def test_gcs_tflite_model_format_source_creation(self):
435415
model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI)
436416
model_format = ml.TFLiteFormat(model_source=model_source)
@@ -442,17 +422,6 @@ def test_gcs_tflite_model_format_source_creation(self):
442422
}
443423
}
444424

445-
def test_auto_ml_tflite_model_format_source_creation(self):
446-
model_source = ml.TFLiteAutoMlSource(auto_ml_model=AUTOML_MODEL_NAME)
447-
model_format = ml.TFLiteFormat(model_source=model_source)
448-
model = ml.Model(display_name=DISPLAY_NAME_1, model_format=model_format)
449-
assert model.as_dict() == {
450-
'displayName': DISPLAY_NAME_1,
451-
'tfliteModel': {
452-
'automlModel': AUTOML_MODEL_NAME
453-
}
454-
}
455-
456425
def test_source_creation_from_tflite_file(self):
457426
model_source = ml.TFLiteGCSModelSource.from_tflite_model_file(
458427
"my_model.tflite", "my_bucket")
@@ -466,13 +435,6 @@ def test_gcs_tflite_model_source_setters(self):
466435
assert model_source.gcs_tflite_uri == GCS_TFLITE_URI_2
467436
assert model_source.as_dict() == GCS_TFLITE_URI_JSON_2
468437

469-
def test_auto_ml_tflite_model_source_setters(self):
470-
model_source = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME)
471-
model_source.auto_ml_model = AUTOML_MODEL_NAME_2
472-
assert model_source.auto_ml_model == AUTOML_MODEL_NAME_2
473-
assert model_source.as_dict() == AUTOML_MODEL_NAME_JSON_2
474-
475-
476438
def test_model_format_setters(self):
477439
model_format = ml.TFLiteFormat(model_source=GCS_TFLITE_MODEL_SOURCE)
478440
model_format.model_source = GCS_TFLITE_MODEL_SOURCE_2
@@ -483,14 +445,6 @@ def test_model_format_setters(self):
483445
}
484446
}
485447

486-
model_format.model_source = AUTOML_MODEL_SOURCE
487-
assert model_format.model_source == AUTOML_MODEL_SOURCE
488-
assert model_format.as_dict() == {
489-
'tfliteModel': {
490-
'automlModel': AUTOML_MODEL_NAME
491-
}
492-
}
493-
494448
def test_model_as_dict_for_upload(self):
495449
model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI)
496450
model_format = ml.TFLiteFormat(model_source=model_source)
@@ -576,23 +530,6 @@ def test_gcs_tflite_source_validation_errors(self, uri, exc_type):
576530
ml.TFLiteGCSModelSource(gcs_tflite_uri=uri)
577531
check_error(excinfo, exc_type)
578532

579-
@pytest.mark.parametrize('auto_ml_model, exc_type', [
580-
(123, TypeError),
581-
('abc', ValueError),
582-
('/projects/123456/locations/us-central1/models/noLeadingSlash', ValueError),
583-
('projects/123546/models/ICN123456', ValueError),
584-
('projects//locations/us-central1/models/ICN123456', ValueError),
585-
('projects/123456/locations//models/ICN123456', ValueError),
586-
('projects/123456/locations/us-central1/models/', ValueError),
587-
('projects/ABC/locations/us-central1/models/ICN123456', ValueError),
588-
('projects/123456/locations/us-central1/models/@#$%^&', ValueError),
589-
('projects/123456/locations/us-cent/ral1/models/ICN123456', ValueError),
590-
])
591-
def test_auto_ml_tflite_source_validation_errors(self, auto_ml_model, exc_type):
592-
with pytest.raises(exc_type) as excinfo:
593-
ml.TFLiteAutoMlSource(auto_ml_model=auto_ml_model)
594-
check_error(excinfo, exc_type)
595-
596533
def test_wait_for_unlocked_not_locked(self):
597534
model = ml.Model(display_name="not_locked")
598535
model.wait_for_unlocked()

0 commit comments

Comments
 (0)