Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class QueryStringConstants(object):
SIGNED_KEY_SERVICE = "sks"
SIGNED_KEY_VERSION = "skv"
SIGNED_ENCRYPTION_SCOPE = "ses"
SIGNED_REQUEST_HEADERS = "srh"
SIGNED_REQUEST_QUERY_PARAMS = "srq"
SIGNED_KEY_DELEGATED_USER_TID = "skdutid"
SIGNED_DELEGATED_USER_OID = "sduoid"

Expand Down Expand Up @@ -81,6 +83,8 @@ def to_list():
QueryStringConstants.SIGNED_KEY_SERVICE,
QueryStringConstants.SIGNED_KEY_VERSION,
QueryStringConstants.SIGNED_ENCRYPTION_SCOPE,
QueryStringConstants.SIGNED_REQUEST_HEADERS,
QueryStringConstants.SIGNED_REQUEST_QUERY_PARAMS,
QueryStringConstants.SIGNED_KEY_DELEGATED_USER_TID,
QueryStringConstants.SIGNED_DELEGATED_USER_OID,
# for ADLS
Expand Down Expand Up @@ -225,6 +229,18 @@ def add_override_response_headers(
self._add_query(QueryStringConstants.SIGNED_CONTENT_LANGUAGE, content_language)
self._add_query(QueryStringConstants.SIGNED_CONTENT_TYPE, content_type)

def add_request_headers(self, request_headers):
if not request_headers:
return
serialized = [str(k) + ":" + str(v) for k, v in request_headers.items()]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: This is already a Dict[str, str] right? So, you shouldn't need the str on the key and value.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okie

self._add_query(QueryStringConstants.SIGNED_REQUEST_HEADERS, "\n".join(serialized) + "\n")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per the spec, request headers should have a trailing \n in addition to the \n separator for different signed values, and request query parameters should have a prefix \n (opposite to request headers).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure it needs one at the end?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just what the spec says ¯_(ツ)_/¯


def add_request_query_params(self, request_query_params):
if not request_query_params:
return
serialized = [str(k) + ":" + str(v) for k, v in request_query_params.items()]
self._add_query(QueryStringConstants.SIGNED_REQUEST_QUERY_PARAMS, "\n" + "\n".join(serialized))

def add_account_signature(self, account_name, account_key):
def get_value_to_append(query):
return_value = self.query_dict.get(query) or ""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def generate_blob(
content_language: Optional[str] = None,
content_type: Optional[str] = None,
user_delegation_oid: Optional[str] = None,
request_headers: Optional[Dict[str, str]] = None,
request_query_params: Optional[Dict[str, str]] = None,
sts_hook: Optional[Callable[[str], None]] = None,
**kwargs: Any
) -> str:
Expand Down Expand Up @@ -141,6 +143,12 @@ def generate_blob(
Specifies the Entra ID of the user that is authorized to use the resulting SAS URL.
The resulting SAS URL must be used in conjunction with an Entra ID token that has been
issued to the user specified in this value.
:param Dict[str, str] request_headers:
If specified, both the correct request header(s) and corresponding values must be present,
or the request will fail.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make this description a little better? As is it doesn't make a lot of sense. Something like "Specifies a set of headers and their corresponding values that must be present in the request when using this SAS.". Somethng similar for query params.

This is internal doc, so it doesn't matter as much but please change the public docs on the public SAS functions. I didn't mention it before because we were just trying to get the API views done.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll take your suggestion as-is here

:param Dict[str, str] request_query_params:
If specified, both the correct query parameter(s) and corresponding values must be present,
or the request will fail.
:param sts_hook:
For debugging purposes only. If provided, the hook is called with the string to sign
that was used to generate the SAS.
Expand All @@ -166,6 +174,8 @@ def generate_blob(
content_type)
sas.add_encryption_scope(**kwargs)
sas.add_info_for_hns_account(**kwargs)
sas.add_request_headers(request_headers)
sas.add_request_query_params(request_query_params)
sas.add_resource_signature(self.account_name, self.account_key, resource_path,
user_delegation_key=self.user_delegation_key)

Expand All @@ -188,6 +198,8 @@ def generate_container(
content_language: Optional[str] = None,
content_type: Optional[str] = None,
user_delegation_oid: Optional[str] = None,
request_headers: Optional[Dict[str, str]] = None,
request_query_params: Optional[Dict[str, str]] = None,
sts_hook: Optional[Callable[[str], None]] = None,
**kwargs: Any
) -> str:
Expand Down Expand Up @@ -251,6 +263,12 @@ def generate_container(
Specifies the Entra ID of the user that is authorized to use the resulting SAS URL.
The resulting SAS URL must be used in conjunction with an Entra ID token that has been
issued to the user specified in this value.
:param Dict[str, str] request_headers:
If specified, both the correct request header(s) and corresponding values must be present,
or the request will fail.
:param Dict[str, str] request_query_params:
If specified, both the correct query parameter(s) and corresponding values must be present,
or the request will fail.
:param sts_hook:
For debugging purposes only. If provided, the hook is called with the string to sign
that was used to generate the SAS.
Expand All @@ -268,6 +286,8 @@ def generate_container(
content_type)
sas.add_encryption_scope(**kwargs)
sas.add_info_for_hns_account(**kwargs)
sas.add_request_headers(request_headers)
sas.add_request_query_params(request_query_params)
sas.add_resource_signature(self.account_name, self.account_key, container_name,
user_delegation_key=self.user_delegation_key)

Expand Down Expand Up @@ -336,6 +356,8 @@ def add_resource_signature(self, account_name, account_key, path, user_delegatio
self.get_value_to_append(QueryStringConstants.SIGNED_RESOURCE) +
self.get_value_to_append(BlobQueryStringConstants.SIGNED_TIMESTAMP) +
self.get_value_to_append(QueryStringConstants.SIGNED_ENCRYPTION_SCOPE) +
self.get_value_to_append(QueryStringConstants.SIGNED_REQUEST_HEADERS) +
self.get_value_to_append(QueryStringConstants.SIGNED_REQUEST_QUERY_PARAMS) +
self.get_value_to_append(QueryStringConstants.SIGNED_CACHE_CONTROL) +
self.get_value_to_append(QueryStringConstants.SIGNED_CONTENT_DISPOSITION) +
self.get_value_to_append(QueryStringConstants.SIGNED_CONTENT_ENCODING) +
Expand Down Expand Up @@ -575,6 +597,8 @@ def generate_container_sas(
policy_id=policy_id,
ip=ip,
user_delegation_oid=user_delegation_oid,
request_headers=request_headers,
request_query_params=request_query_params,
sts_hook=sts_hook,
**kwargs
)
Expand Down Expand Up @@ -725,8 +749,10 @@ def generate_blob_sas(
start=start,
policy_id=policy_id,
ip=ip,
sts_hook=sts_hook,
user_delegation_oid=user_delegation_oid,
request_headers=request_headers,
request_query_params=request_query_params,
sts_hook=sts_hook,
**kwargs
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class QueryStringConstants(object):
SIGNED_KEY_SERVICE = "sks"
SIGNED_KEY_VERSION = "skv"
SIGNED_ENCRYPTION_SCOPE = "ses"
SIGNED_REQUEST_HEADERS = "srh"
SIGNED_REQUEST_QUERY_PARAMS = "srq"
SIGNED_KEY_DELEGATED_USER_TID = "skdutid"
SIGNED_DELEGATED_USER_OID = "sduoid"

Expand Down Expand Up @@ -81,6 +83,8 @@ def to_list():
QueryStringConstants.SIGNED_KEY_SERVICE,
QueryStringConstants.SIGNED_KEY_VERSION,
QueryStringConstants.SIGNED_ENCRYPTION_SCOPE,
QueryStringConstants.SIGNED_REQUEST_HEADERS,
QueryStringConstants.SIGNED_REQUEST_QUERY_PARAMS,
QueryStringConstants.SIGNED_KEY_DELEGATED_USER_TID,
QueryStringConstants.SIGNED_DELEGATED_USER_OID,
# for ADLS
Expand Down Expand Up @@ -225,6 +229,18 @@ def add_override_response_headers(
self._add_query(QueryStringConstants.SIGNED_CONTENT_LANGUAGE, content_language)
self._add_query(QueryStringConstants.SIGNED_CONTENT_TYPE, content_type)

def add_request_headers(self, request_headers):
if not request_headers:
return
serialized = [str(k) + ":" + str(v) for k, v in request_headers.items()]
self._add_query(QueryStringConstants.SIGNED_REQUEST_HEADERS, "\n".join(serialized) + "\n")

def add_request_query_params(self, request_query_params):
if not request_query_params:
return
serialized = [str(k) + ":" + str(v) for k, v in request_query_params.items()]
self._add_query(QueryStringConstants.SIGNED_REQUEST_QUERY_PARAMS, "\n" + "\n".join(serialized))

def add_account_signature(self, account_name, account_key):
def get_value_to_append(query):
return_value = self.query_dict.get(query) or ""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ def generate_file_system_sas(
permission=cast(Optional[Union["ContainerSasPermissions", str]], permission),
expiry=expiry,
user_delegation_oid=user_delegation_oid,
request_headers=request_headers,
request_query_params=request_query_params,
sts_hook=sts_hook,
**kwargs
)
Expand Down Expand Up @@ -353,6 +355,8 @@ def generate_directory_sas(
sdd=depth,
is_directory=True,
user_delegation_oid=user_delegation_oid,
request_headers=request_headers,
request_query_params=request_query_params,
sts_hook=sts_hook,
**kwargs
)
Expand Down Expand Up @@ -488,8 +492,10 @@ def generate_file_sas(
user_delegation_key=credential if not isinstance(credential, str) else None,
permission=cast(Optional[Union["BlobSasPermissions", str]], permission),
expiry=expiry,
sts_hook=sts_hook,
user_delegation_oid=user_delegation_oid,
request_headers=request_headers,
request_query_params=request_query_params,
sts_hook=sts_hook,
**kwargs
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class QueryStringConstants(object):
SIGNED_KEY_SERVICE = "sks"
SIGNED_KEY_VERSION = "skv"
SIGNED_ENCRYPTION_SCOPE = "ses"
SIGNED_REQUEST_HEADERS = "srh"
SIGNED_REQUEST_QUERY_PARAMS = "srq"
SIGNED_KEY_DELEGATED_USER_TID = "skdutid"
SIGNED_DELEGATED_USER_OID = "sduoid"

Expand Down Expand Up @@ -81,6 +83,8 @@ def to_list():
QueryStringConstants.SIGNED_KEY_SERVICE,
QueryStringConstants.SIGNED_KEY_VERSION,
QueryStringConstants.SIGNED_ENCRYPTION_SCOPE,
QueryStringConstants.SIGNED_REQUEST_HEADERS,
QueryStringConstants.SIGNED_REQUEST_QUERY_PARAMS,
QueryStringConstants.SIGNED_KEY_DELEGATED_USER_TID,
QueryStringConstants.SIGNED_DELEGATED_USER_OID,
# for ADLS
Expand Down Expand Up @@ -218,6 +222,18 @@ def add_override_response_headers(
self._add_query(QueryStringConstants.SIGNED_CONTENT_LANGUAGE, content_language)
self._add_query(QueryStringConstants.SIGNED_CONTENT_TYPE, content_type)

def add_request_headers(self, request_headers):
if not request_headers:
return
serialized = [str(k) + ":" + str(v) for k, v in request_headers.items()]
self._add_query(QueryStringConstants.SIGNED_REQUEST_HEADERS, "\n".join(serialized) + "\n")

def add_request_query_params(self, request_query_params):
if not request_query_params:
return
serialized = [str(k) + ":" + str(v) for k, v in request_query_params.items()]
self._add_query(QueryStringConstants.SIGNED_REQUEST_QUERY_PARAMS, "\n" + "\n".join(serialized))

def add_account_signature(self, account_name, account_key):
def get_value_to_append(query):
return_value = self.query_dict.get(query) or ""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class QueryStringConstants(object):
SIGNED_KEY_SERVICE = "sks"
SIGNED_KEY_VERSION = "skv"
SIGNED_ENCRYPTION_SCOPE = "ses"
SIGNED_REQUEST_HEADERS = "srh"
SIGNED_REQUEST_QUERY_PARAMS = "srq"
SIGNED_KEY_DELEGATED_USER_TID = "skdutid"
SIGNED_DELEGATED_USER_OID = "sduoid"

Expand Down Expand Up @@ -81,6 +83,8 @@ def to_list():
QueryStringConstants.SIGNED_KEY_SERVICE,
QueryStringConstants.SIGNED_KEY_VERSION,
QueryStringConstants.SIGNED_ENCRYPTION_SCOPE,
QueryStringConstants.SIGNED_REQUEST_HEADERS,
QueryStringConstants.SIGNED_REQUEST_QUERY_PARAMS,
QueryStringConstants.SIGNED_KEY_DELEGATED_USER_TID,
QueryStringConstants.SIGNED_DELEGATED_USER_OID,
# for ADLS
Expand Down Expand Up @@ -226,6 +230,18 @@ def add_override_response_headers(
self._add_query(QueryStringConstants.SIGNED_CONTENT_LANGUAGE, content_language)
self._add_query(QueryStringConstants.SIGNED_CONTENT_TYPE, content_type)

def add_request_headers(self, request_headers):
if not request_headers:
return
serialized = [str(k) + ":" + str(v) for k, v in request_headers.items()]
self._add_query(QueryStringConstants.SIGNED_REQUEST_HEADERS, "\n".join(serialized) + "\n")

def add_request_query_params(self, request_query_params):
if not request_query_params:
return
serialized = [str(k) + ":" + str(v) for k, v in request_query_params.items()]
self._add_query(QueryStringConstants.SIGNED_REQUEST_QUERY_PARAMS, "\n" + "\n".join(serialized))

def add_account_signature(self, account_name, account_key):
def get_value_to_append(query):
return_value = self.query_dict.get(query) or ""
Expand Down