Skip to content

Commit 2899dfd

Browse files
authored
Add support for Azure Shared Access Signatures (#1537)
* Add support for Azure SAS tokens
* Add support for bearer tokens
* Allow both storage account sas and container sas
* Apply clang format
* Add SAS to documentation
* Remove token_credential again since it does not support renew
* Format ipynb
* Get SAS during test time
* Install az cli locally
* Install azure-cli inside test container
* Make sure commands are valid
* Ensure env is setup
1 parent 312044a commit 2899dfd

File tree

4 files changed

+118
-30
lines changed

4 files changed

+118
-30
lines changed

.github/workflows/build.wheel.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ run_test() {
88
CPYTHON_VERSION=$($entry -c 'import sys; print(str(sys.version_info[0])+str(sys.version_info[1]))')
99
(cd wheelhouse && $entry -m pip install tensorflow_io_gcs_filesystem-*-cp${CPYTHON_VERSION}-*.whl)
1010
(cd wheelhouse && $entry -m pip install tensorflow_io-*-cp${CPYTHON_VERSION}-*.whl)
11-
$entry -m pip install -q pytest pytest-benchmark pytest-xdist boto3 fastavro avro-python3 scikit-image pandas pyarrow==3.0.0 google-cloud-pubsub==2.1.0 google-cloud-bigtable==1.6.0 google-cloud-bigquery-storage==1.1.0 google-cloud-bigquery==2.3.1 google-cloud-storage==1.32.0 PyYAML==5.3.1 azure-storage-blob==12.8.1
11+
$entry -m pip install -q pytest pytest-benchmark pytest-xdist boto3 fastavro avro-python3 scikit-image pandas pyarrow==3.0.0 google-cloud-pubsub==2.1.0 google-cloud-bigtable==1.6.0 google-cloud-bigquery-storage==1.1.0 google-cloud-bigquery==2.3.1 google-cloud-storage==1.32.0 PyYAML==5.3.1 azure-storage-blob==12.8.1 azure-cli==2.29.0
1212
(cd tests && $entry -m pytest --benchmark-disable -v --import-mode=append --forked --numprocesses=auto --dist loadfile $(find . -type f \( -iname "test_*.py" ! \( -iname "test_standalone_*.py" \) \)))
1313
(cd tests && $entry -m pytest --benchmark-disable -v --import-mode=append $(find . -type f \( -iname "test_standalone_*.py" \)))
1414
}

docs/tutorials/azure.ipynb

+4
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,10 @@
278278
"else:\n",
279279
" # Replace <key> with Azure Storage Key, and <account> with Azure Storage Account\n",
280280
" os.environ['TF_AZURE_STORAGE_KEY'] = '<key>'\n",
281+
" account_name = '<account>'\n",
282+
"\n",
283+
" # Alternatively, you can use a shared access signature (SAS) to authenticate with the Azure Storage Account\n",
284+
" os.environ['TF_AZURE_STORAGE_SAS'] = '<your sas>'\n",
281285
" account_name = '<account>'"
282286
]
283287
},

tensorflow_io/core/filesystems/az/az_filesystem.cc

+31-18
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,20 @@ std::string errno_to_string() {
184184
}
185185

186186
std::shared_ptr<azure::storage_lite::storage_credential> get_credential(
187-
const std::string& account) {
188-
const auto key = std::getenv("TF_AZURE_STORAGE_KEY");
189-
if (key != nullptr) {
187+
const std::string& account, const std::string& container) {
188+
const std::string sas_account_container_env =
189+
"TF_AZURE_STORAGE_" + account + "_" + container + "_SAS";
190+
const std::string sas_account_env = "TF_AZURE_STORAGE_" + account + "_SAS";
191+
if (const auto sas = std::getenv(sas_account_container_env.c_str())) {
192+
return std::make_shared<
193+
azure::storage_lite::shared_access_signature_credential>(sas);
194+
} else if (const auto sas = std::getenv(sas_account_env.c_str())) {
195+
return std::make_shared<
196+
azure::storage_lite::shared_access_signature_credential>(sas);
197+
} else if (const auto sas = std::getenv("TF_AZURE_STORAGE_SAS")) {
198+
return std::make_shared<
199+
azure::storage_lite::shared_access_signature_credential>(sas);
200+
} else if (const auto key = std::getenv("TF_AZURE_STORAGE_KEY")) {
190201
return std::make_shared<azure::storage_lite::shared_key_credential>(account,
191202
key);
192203
} else {
@@ -195,7 +206,7 @@ std::shared_ptr<azure::storage_lite::storage_credential> get_credential(
195206
}
196207

197208
azure::storage_lite::blob_client_wrapper CreateAzBlobClientWrapper(
198-
const std::string& account) {
209+
const std::string& account, const std::string& container) {
199210
azure::storage_lite::logger::set_logger(
200211
[](azure::storage_lite::log_level level, const std::string& log_msg) {
201212
switch (level) {
@@ -232,7 +243,7 @@ azure::storage_lite::blob_client_wrapper CreateAzBlobClientWrapper(
232243
const auto blob_endpoint =
233244
std::string(blob_endpoint_env ? blob_endpoint_env : "");
234245

235-
auto credentials = get_credential(account);
246+
auto credentials = get_credential(account, container);
236247
auto storage_account = std::make_shared<azure::storage_lite::storage_account>(
237248
account, credentials, use_https, blob_endpoint);
238249
auto blob_client =
@@ -324,10 +335,12 @@ class AzBlobRandomAccessFile {
324335
TF_SetStatus(status, TF_OK, "");
325336
return 0;
326337
}
327-
auto blob_client = CreateAzBlobClientWrapper(account_);
338+
auto blob_client = CreateAzBlobClientWrapper(account_, container_);
328339
auto blob_property = blob_client.get_blob_property(container_, object_);
329340
if (errno != 0) {
330-
TF_SetStatus(status, TF_INTERNAL, "Failed to get properties");
341+
std::string error_message =
342+
absl::StrCat("Failed to get properties ", errno);
343+
TF_SetStatus(status, TF_INTERNAL, error_message.c_str());
331344
return 0;
332345
}
333346
int64_t file_size = blob_property.size;
@@ -433,7 +446,7 @@ class AzBlobWritableFile {
433446
return;
434447
}
435448

436-
auto blob_client = CreateAzBlobClientWrapper(account_);
449+
auto blob_client = CreateAzBlobClientWrapper(account_, container_);
437450
blob_client.upload_file_to_blob(tmp_content_filename_, container_, object_);
438451
if (errno != 0) {
439452
std::string error_message =
@@ -476,7 +489,7 @@ Status GetMatchingPaths(const std::string& pattern, std::vector<std::string>* re
476489
TF_RETURN_IF_ERROR(
477490
ParseAzBlobPathClass(fixed_prefix, true, &account, &container, &object));
478491

479-
auto blob_client = CreateAzBlobClientWrapper(account);
492+
auto blob_client = CreateAzBlobClientWrapper(account, container);
480493

481494
std::vector<std::string> blobs;
482495
TF_RETURN_IF_ERROR(ListResources(fixed_prefix, "", blob_client, &blobs));
@@ -637,7 +650,7 @@ static void CreateDir(const TF_Filesystem* filesystem, const char* path,
637650
}
638651

639652
// Blob storage has virtual folders. We can make sure the container exists
640-
auto blob_client_wrapper = CreateAzBlobClientWrapper(account);
653+
auto blob_client_wrapper = CreateAzBlobClientWrapper(account, container);
641654

642655
if (blob_client_wrapper.container_exists(container)) {
643656
TF_SetStatus(status, TF_OK, "");
@@ -667,7 +680,7 @@ static void DeleteFile(const TF_Filesystem* filesystem, const char* path,
667680
return;
668681
}
669682

670-
auto blob_client = CreateAzBlobClientWrapper(account);
683+
auto blob_client = CreateAzBlobClientWrapper(account, container);
671684

672685
blob_client.delete_blob(container, object);
673686
if (errno != 0) {
@@ -698,7 +711,7 @@ static void DeleteDir(const TF_Filesystem* filesystem, const char* path,
698711
return;
699712
}
700713

701-
auto blob_client = CreateAzBlobClientWrapper(account);
714+
auto blob_client = CreateAzBlobClientWrapper(account, container);
702715

703716
// Check container exists
704717
// Just pull out the first path component representing the container
@@ -764,7 +777,7 @@ static void RenameFile(const TF_Filesystem* filesystem, const char* src,
764777
return;
765778
}
766779

767-
auto blob_client = CreateAzBlobClientWrapper(src_account);
780+
auto blob_client = CreateAzBlobClientWrapper(src_account, src_container);
768781

769782
blob_client.start_copy(src_container, src_object, dst_container, dst_object);
770783
if (errno != 0) {
@@ -860,7 +873,7 @@ static void PathExists(const TF_Filesystem* filesystem, const char* path,
860873
return;
861874
}
862875

863-
auto blob_client = CreateAzBlobClientWrapper(account);
876+
auto blob_client = CreateAzBlobClientWrapper(account, container);
864877
auto blob_exists = blob_client.blob_exists(container, object);
865878
if (errno != 0) {
866879
std::string error_message = absl::StrCat(
@@ -889,7 +902,7 @@ static bool IsDirectory(const TF_Filesystem* filesystem, const char* path,
889902
return false;
890903
}
891904

892-
auto blob_client = CreateAzBlobClientWrapper(account);
905+
auto blob_client = CreateAzBlobClientWrapper(account, container);
893906

894907
if (container.empty()) {
895908
TF_SetStatus(status, TF_UNIMPLEMENTED,
@@ -935,7 +948,7 @@ static void Stat(const TF_Filesystem* filesystem, const char* path,
935948
return;
936949
}
937950

938-
auto blob_client = CreateAzBlobClientWrapper(account);
951+
auto blob_client = CreateAzBlobClientWrapper(account, container);
939952

940953
if (IsDirectory(filesystem, path, status)) {
941954
stats->length = 0;
@@ -972,7 +985,7 @@ static int GetChildren(const TF_Filesystem* filesystem, const char* path,
972985
return 0;
973986
}
974987

975-
auto blob_client = CreateAzBlobClientWrapper(account);
988+
auto blob_client = CreateAzBlobClientWrapper(account, container);
976989

977990
std::string continuation_token;
978991
if (container.empty()) {
@@ -1049,7 +1062,7 @@ static int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
10491062
return 0;
10501063
}
10511064

1052-
auto blob_client = CreateAzBlobClientWrapper(account);
1065+
auto blob_client = CreateAzBlobClientWrapper(account, container);
10531066
auto blob_property = blob_client.get_blob_property(container, object);
10541067
if (errno != 0) {
10551068
std::string error_message = absl::StrCat(

tests/test_azure.py

+82-11
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Tests for Azure File System."""
1616

1717
import os
18+
import subprocess
1819
import sys
1920
import pytest
2021

@@ -27,22 +28,13 @@
2728
pytest.skip("TODO: skip macOS", allow_module_level=True)
2829

2930

30-
class AZFSTest(tf.test.TestCase):
31+
class AZFSTestBase:
3132
"""[summary]
3233
3334
Args:
3435
test {[type]} -- [description]
3536
"""
3637

37-
def __init__(self, methodName="runTest"): # pylint: disable=invalid-name
38-
39-
os.environ["TF_AZURE_USE_DEV_STORAGE"] = "1"
40-
41-
self.account = "devstoreaccount1"
42-
self.container = "aztest"
43-
self.path_root = "az://" + os.path.join(self.account, self.container)
44-
super().__init__(methodName)
45-
4638
def _path_to(self, path):
4739
return os.path.join(self.path_root, path)
4840

@@ -59,7 +51,9 @@ def test_also_works_with_full_dns_name(self):
5951
a path of the form
6052
az://<account>.blob.core.windows.net/<container>/<path>
6153
"""
62-
file_name = self.account + ".blob.core.windows.net" + self.container
54+
file_name = "az://" + os.path.join(
55+
self.account + ".blob.core.windows.net", self.container
56+
)
6357
if not tf.io.gfile.isdir(file_name):
6458
tf.io.gfile.makedirs(file_name)
6559

@@ -208,5 +202,82 @@ def _test_read_file_offset_and_dataset(self):
208202
assert i == 2
209203

210204

205+
class AZFSTest(tf.test.TestCase, AZFSTestBase):
206+
"""Run tests for azfs backend using account key authentication."""
207+
208+
def __init__(self, methodName="runTest"): # pylint: disable=invalid-name
209+
210+
self.account = "devstoreaccount1"
211+
self.container = "aztest"
212+
self.path_root = "az://" + os.path.join(self.account, self.container)
213+
super().__init__(methodName)
214+
215+
def setUp(self):
216+
super().setUp()
217+
218+
os.environ["TF_AZURE_USE_DEV_STORAGE"] = "1"
219+
220+
221+
class AZFSSASTest(tf.test.TestCase, AZFSTestBase):
222+
"""Run tests for azfs backend using shared access signature authentication."""
223+
224+
def __init__(self, methodName="runTest"): # pylint: disable=invalid-name
225+
self.account = "devstoreaccount1"
226+
self.container = "aztest"
227+
self.path_root = "az://" + os.path.join(self.account, self.container)
228+
super().__init__(methodName)
229+
230+
def setUp(self):
231+
super().setUp()
232+
233+
if "TF_AZURE_USE_DEV_STORAGE" in os.environ:
234+
del os.environ["TF_AZURE_USE_DEV_STORAGE"]
235+
236+
os.environ["TF_AZURE_STORAGE_USE_HTTP"] = "1"
237+
os.environ[
238+
"TF_AZURE_STORAGE_BLOB_ENDPOINT"
239+
] = "127.0.0.1:10000/devstoreaccount1"
240+
241+
sas_end = (
242+
subprocess.check_output(["date", "--date", "1 days", r"+%FT%TZ"])
243+
.decode()
244+
.rstrip()
245+
)
246+
247+
env = os.environ.copy()
248+
env["AZURE_STORAGE_CONNECTION_STRING"] = (
249+
"DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;"
250+
"AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;"
251+
"BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;"
252+
"QueueEndpoint=http://127.0.0.1:10001/devstoreaccount1;"
253+
"TableEndpoint=http://127.0.0.1:10002/devstoreaccount1;"
254+
)
255+
256+
os.environ["TF_AZURE_STORAGE_SAS"] = (
257+
subprocess.check_output(
258+
[
259+
"az",
260+
"storage",
261+
"account",
262+
"generate-sas",
263+
"-otsv",
264+
"--permissions",
265+
"acdlpruw",
266+
"--resource-types",
267+
"sco",
268+
"--services",
269+
"b",
270+
"--expiry",
271+
sas_end,
272+
"--account-name",
273+
"devstoreaccount1",
274+
],
275+
env=env,
276+
)
277+
.decode()
278+
.rstrip()
279+
)
280+
281+
211282
if __name__ == "__main__":
212283
tf.test.main()

0 commit comments

Comments
 (0)