Skip to content

Commit ef75e7b

Browse files
committed
feat: adds util function to get available first factors
- Moves is_recipe_initialized to supertokens.asyncio - Cleans up supertokens __init__ file to reduce redundancy - Adds test to ensure FactorIds class and method are in sync ref: supertokens/supertokens-node#1021
1 parent da78c98 commit ef75e7b

File tree

6 files changed

+126
-46
lines changed

6 files changed

+126
-46
lines changed

supertokens_python/__init__.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
1414

15-
from typing import Any, Dict, List, Optional
15+
from typing import List, Optional
1616

1717
from typing_extensions import Literal
1818

19-
from supertokens_python.framework.request import BaseRequest
2019
from supertokens_python.recipe_module import RecipeModule
2120
from supertokens_python.types import RecipeUserId
2221

@@ -30,6 +29,7 @@
3029
SupertokensExperimentalConfig,
3130
SupertokensInputConfig,
3231
SupertokensPublicConfig,
32+
get_request_from_user_context,
3333
)
3434

3535
# Some Pydantic models need a rebuild to resolve ForwardRefs
@@ -69,19 +69,10 @@ def get_all_cors_headers() -> List[str]:
6969
return Supertokens.get_instance().get_all_cors_headers()
7070

7171

72-
def get_request_from_user_context(
73-
user_context: Optional[Dict[str, Any]],
74-
) -> Optional[BaseRequest]:
75-
return Supertokens.get_instance().get_request_from_user_context(user_context)
76-
77-
7872
def convert_to_recipe_user_id(user_id: str) -> RecipeUserId:
7973
return RecipeUserId(user_id)
8074

8175

82-
is_recipe_initialized = Supertokens.is_recipe_initialized
83-
84-
8576
__all__ = [
8677
"AppInfo",
8778
"InputAppInfo",
@@ -95,5 +86,4 @@ def convert_to_recipe_user_id(user_id: str) -> RecipeUserId:
9586
"get_all_cors_headers",
9687
"get_request_from_user_context",
9788
"init",
98-
"is_recipe_initialized",
9989
]

supertokens_python/asyncio/__init__.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from typing import Any, Dict, List, Optional, Union
1515

1616
from supertokens_python import Supertokens
17+
from supertokens_python.exceptions import BadInputError
1718
from supertokens_python.interfaces import (
1819
CreateUserIdMappingOkResult,
1920
DeleteUserIdMappingOkResult,
@@ -26,8 +27,9 @@
2627
)
2728
from supertokens_python.recipe.accountlinking.interfaces import GetUsersResult
2829
from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe
30+
from supertokens_python.recipe.session.interfaces import SessionContainer
2931
from supertokens_python.types import User
30-
from supertokens_python.types.base import AccountInfoInput
32+
from supertokens_python.types.base import AccountInfoInput, UserContext
3133

3234

3335
async def get_users_oldest_first(
@@ -172,3 +174,44 @@ async def list_users_by_account_info(
172174
do_union_of_account_info,
173175
user_context,
174176
)
177+
178+
179+
# Async not really required, but keeping for consistency
180+
async def is_recipe_initialized(recipe_id: str) -> bool:
181+
"""
182+
Check if a recipe is initialized.
183+
:param recipe_id: The ID of the recipe to check.
184+
:return: Whether the recipe is initialized.
185+
"""
186+
return any(
187+
recipe.get_recipe_id() == recipe_id
188+
for recipe in Supertokens.get_instance().recipe_modules
189+
)
190+
191+
192+
async def get_available_first_factors(
193+
tenant_id: str,
194+
session: Optional[SessionContainer],
195+
user_context: Optional[UserContext],
196+
):
197+
from supertokens_python.auth_utils import (
198+
filter_out_invalid_first_factors_or_throw_if_all_are_invalid,
199+
)
200+
from supertokens_python.recipe.multifactorauth.types import FactorIds
201+
202+
available_first_factors: List[str] = []
203+
204+
try:
205+
available_first_factors = (
206+
await filter_out_invalid_first_factors_or_throw_if_all_are_invalid(
207+
factor_ids=FactorIds.get_all_factors(),
208+
tenant_id=tenant_id,
209+
has_session=session is not None,
210+
user_context=user_context if user_context is not None else {},
211+
)
212+
)
213+
except BadInputError:
214+
# All the factors were invalid, so we let it pass through and return the empty list
215+
pass
216+
217+
return available_first_factors

supertokens_python/recipe/multifactorauth/types.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,27 @@ class NormalisedMultiFactorAuthConfig(
6161

6262

6363
class FactorIds:
64-
EMAILPASSWORD: Literal["emailpassword"] = "emailpassword"
65-
OTP_EMAIL: Literal["otp-email"] = "otp-email"
66-
OTP_PHONE: Literal["otp-phone"] = "otp-phone"
67-
LINK_EMAIL: Literal["link-email"] = "link-email"
68-
LINK_PHONE: Literal["link-phone"] = "link-phone"
69-
THIRDPARTY: Literal["thirdparty"] = "thirdparty"
70-
TOTP: Literal["totp"] = "totp"
71-
WEBAUTHN: Literal["webauthn"] = "webauthn"
64+
EMAILPASSWORD = "emailpassword"
65+
OTP_EMAIL = "otp-email"
66+
OTP_PHONE = "otp-phone"
67+
LINK_EMAIL = "link-email"
68+
LINK_PHONE = "link-phone"
69+
THIRDPARTY = "thirdparty"
70+
TOTP = "totp"
71+
WEBAUTHN = "webauthn"
72+
73+
@staticmethod
74+
def get_all_factors():
75+
return [
76+
FactorIds.EMAILPASSWORD,
77+
FactorIds.OTP_EMAIL,
78+
FactorIds.OTP_PHONE,
79+
FactorIds.LINK_EMAIL,
80+
FactorIds.LINK_PHONE,
81+
FactorIds.THIRDPARTY,
82+
FactorIds.TOTP,
83+
FactorIds.WEBAUTHN,
84+
]
7285

7386

7487
class FactorIdsAndType:

supertokens_python/supertokens.py

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
SuperTokensPlugin,
4343
SuperTokensPublicPlugin,
4444
)
45+
from supertokens_python.types.base import UserContext
4546
from supertokens_python.types.response import CamelCaseBaseModel
4647

4748
from .constants import FDI_KEY_HEADER, RID_KEY_HEADER, USER_COUNT
@@ -181,13 +182,13 @@ def __init__(
181182
self.mode = mode
182183

183184
def get_top_level_website_domain(
184-
self, request: Optional[BaseRequest], user_context: Dict[str, Any]
185+
self, request: Optional[BaseRequest], user_context: UserContext
185186
) -> str:
186187
return get_top_level_domain_for_same_site_resolution(
187188
self.get_origin(request, user_context).get_as_string_dangerous()
188189
)
189190

190-
def get_origin(self, request: Optional[BaseRequest], user_context: Dict[str, Any]):
191+
def get_origin(self, request: Optional[BaseRequest], user_context: UserContext):
191192
origin = self.__origin
192193
if origin is None:
193194
origin = self.__website_domain
@@ -211,7 +212,7 @@ def defaultImpl(o: Any):
211212

212213

213214
def manage_session_post_response(
214-
session: SessionContainer, response: BaseResponse, user_context: Dict[str, Any]
215+
session: SessionContainer, response: BaseResponse, user_context: UserContext
215216
):
216217
# Something similar happens in handle_error of session/recipe.py
217218
for mutator in session.response_mutators:
@@ -577,7 +578,7 @@ async def get_user_count(
577578
self,
578579
include_recipe_ids: Union[None, List[str]],
579580
tenant_id: Optional[str] = None,
580-
user_context: Optional[Dict[str, Any]] = None,
581+
user_context: Optional[UserContext] = None,
581582
) -> int:
582583
querier = Querier.get_instance(None)
583584
include_recipe_ids_str = None
@@ -601,7 +602,7 @@ async def create_user_id_mapping(
601602
external_user_id: str,
602603
external_user_id_info: Optional[str],
603604
force: Optional[bool],
604-
user_context: Optional[Dict[str, Any]],
605+
user_context: Optional[UserContext],
605606
) -> Union[
606607
CreateUserIdMappingOkResult,
607608
UnknownSupertokensUserIDError,
@@ -641,7 +642,7 @@ async def get_user_id_mapping(
641642
self,
642643
user_id: str,
643644
user_id_type: Optional[UserIDTypes],
644-
user_context: Optional[Dict[str, Any]],
645+
user_context: Optional[UserContext],
645646
) -> Union[GetUserIdMappingOkResult, UnknownMappingError]:
646647
querier = Querier.get_instance(None)
647648

@@ -676,7 +677,7 @@ async def delete_user_id_mapping(
676677
user_id: str,
677678
user_id_type: Optional[UserIDTypes],
678679
force: Optional[bool],
679-
user_context: Optional[Dict[str, Any]],
680+
user_context: Optional[UserContext],
680681
) -> DeleteUserIdMappingOkResult:
681682
querier = Querier.get_instance(None)
682683

@@ -708,7 +709,7 @@ async def update_or_delete_user_id_mapping_info(
708709
user_id: str,
709710
user_id_type: Optional[UserIDTypes],
710711
external_user_id_info: Optional[str],
711-
user_context: Optional[Dict[str, Any]],
712+
user_context: Optional[UserContext],
712713
) -> Union[UpdateOrDeleteUserIdMappingInfoOkResult, UnknownMappingError]:
713714
querier = Querier.get_instance(None)
714715

@@ -734,7 +735,7 @@ async def update_or_delete_user_id_mapping_info(
734735
raise_general_exception("Please upgrade the SuperTokens core to >= 3.15.0")
735736

736737
async def middleware(
737-
self, request: BaseRequest, response: BaseResponse, user_context: Dict[str, Any]
738+
self, request: BaseRequest, response: BaseResponse, user_context: UserContext
738739
) -> Union[BaseResponse, None]:
739740
from supertokens_python.recipe.session.recipe import SessionRecipe
740741

@@ -907,7 +908,7 @@ async def handle_supertokens_error(
907908
request: BaseRequest,
908909
err: Exception,
909910
response: BaseResponse,
910-
user_context: Dict[str, Any],
911+
user_context: UserContext,
911912
) -> Optional[BaseResponse]:
912913
log_debug_message("errorHandler: Started")
913914
log_debug_message(
@@ -935,7 +936,7 @@ async def handle_supertokens_error(
935936

936937
def get_request_from_user_context(
937938
self,
938-
user_context: Optional[Dict[str, Any]] = None,
939+
user_context: Optional[UserContext] = None,
939940
) -> Optional[BaseRequest]:
940941
if user_context is None:
941942
return None
@@ -948,20 +949,8 @@ def get_request_from_user_context(
948949

949950
return user_context.get("_default", {}).get("request")
950951

951-
@staticmethod
952-
def is_recipe_initialized(recipe_id: str) -> bool:
953-
"""
954-
Check if a recipe is initialized.
955-
:param recipe_id: The ID of the recipe to check.
956-
:return: Whether the recipe is initialized.
957-
"""
958-
return any(
959-
recipe.get_recipe_id() == recipe_id
960-
for recipe in Supertokens.get_instance().recipe_modules
961-
)
962-
963952

964953
def get_request_from_user_context(
965-
user_context: Optional[Dict[str, Any]],
954+
user_context: Optional[UserContext],
966955
) -> Optional[BaseRequest]:
967956
return Supertokens.get_instance().get_request_from_user_context(user_context)

supertokens_python/syncio/__init__.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@
2525
UserIdMappingAlreadyExistsError,
2626
UserIDTypes,
2727
)
28+
from supertokens_python.recipe.session.interfaces import SessionContainer
2829
from supertokens_python.types import User
29-
from supertokens_python.types.base import AccountInfoInput
30+
from supertokens_python.types.base import AccountInfoInput, UserContext
3031

3132

3233
def get_users_oldest_first(
@@ -178,3 +179,32 @@ def list_users_by_account_info(
178179
tenant_id, account_info, do_union_of_account_info, user_context
179180
)
180181
)
182+
183+
184+
def is_recipe_initialized(recipe_id: str) -> bool:
185+
"""
186+
Check if a recipe is initialized.
187+
:param recipe_id: The ID of the recipe to check.
188+
:return: Whether the recipe is initialized.
189+
"""
190+
from supertokens_python.asyncio import (
191+
is_recipe_initialized as async_is_recipe_initialized,
192+
)
193+
194+
return sync(async_is_recipe_initialized(recipe_id))
195+
196+
197+
def get_available_first_factors(
198+
tenant_id: str,
199+
session: Optional[SessionContainer],
200+
user_context: Optional[UserContext],
201+
):
202+
from supertokens_python.asyncio import (
203+
get_available_first_factors as async_get_available_first_factors,
204+
)
205+
206+
return sync(
207+
async_get_available_first_factors(
208+
tenant_id=tenant_id, session=session, user_context=user_context
209+
)
210+
)

tests/test_mfa.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from supertokens_python.recipe.multifactorauth.types import FactorIds
2+
3+
4+
def test_get_all_factors():
5+
"""Test that FactorIds.get_all_factors returns all factors defined in FactorIds class."""
6+
factors_from_dict: list[str] = []
7+
for k, v in FactorIds.__dict__.items():
8+
if (
9+
(not k.startswith("__") or not k.endswith("__"))
10+
and not k.startswith("<")
11+
and isinstance(v, str)
12+
):
13+
factors_from_dict.append(v)
14+
15+
assert factors_from_dict == FactorIds.get_all_factors()

0 commit comments

Comments
 (0)