diff --git a/enterprise/api/v1/urls.py b/enterprise/api/v1/urls.py index be34f8344..f0d1fa62d 100644 --- a/enterprise/api/v1/urls.py +++ b/enterprise/api/v1/urls.py @@ -32,6 +32,8 @@ ) from enterprise.api.v1.views.enterprise_customer_admin import EnterpriseCustomerAdminViewSet from enterprise.api.v1.views.enterprise_sso_users import EnterpriseSSOUserViewSet +from enterprise.api.v1.views.saml_provider_config import SAMLProviderConfigViewSet +from enterprise.api.v1.views.saml_provider_data import SAMLProviderDataViewSet router = DefaultRouter() router.register( @@ -257,3 +259,11 @@ ] urlpatterns += router.urls + +# SAML provider admin endpoints (migrated from openedx-platform third_party_auth). +_saml_router = DefaultRouter() +_saml_router.register(r'auth/saml/v0/provider_config', SAMLProviderConfigViewSet, + basename='enterprise-saml-provider-config') +_saml_router.register(r'auth/saml/v0/provider_data', SAMLProviderDataViewSet, + basename='enterprise-saml-provider-data') +urlpatterns += _saml_router.urls diff --git a/enterprise/api/v1/views/saml_provider_config.py b/enterprise/api/v1/views/saml_provider_config.py new file mode 100644 index 000000000..be324a475 --- /dev/null +++ b/enterprise/api/v1/views/saml_provider_config.py @@ -0,0 +1,102 @@ +""" +Viewset for enterprise SAML provider config administration. +""" +from django.db.utils import IntegrityError +from django.shortcuts import get_list_or_404 +from edx_rbac.mixins import PermissionRequiredMixin +from rest_framework import permissions, status, viewsets +from rest_framework.exceptions import ParseError, ValidationError +from rest_framework.response import Response + +from enterprise.models import EnterpriseCustomer, EnterpriseCustomerIdentityProvider + + +class SAMLProviderConfigViewSet(PermissionRequiredMixin, viewsets.ModelViewSet): + """ + A View to handle SAMLProviderConfig CRUD for enterprise admin users. + + Usage:: + + GET /enterprise/api/v1/auth/saml/v0/provider_config/?enterprise-id= + POST /enterprise/api/v1/auth/saml/v0/provider_config/ + PATCH /enterprise/api/v1/auth/saml/v0/provider_config// + DELETE /enterprise/api/v1/auth/saml/v0/provider_config// + """ + + permission_classes = [permissions.IsAuthenticated] + permission_required = 'enterprise.can_access_admin_dashboard' + + def _get_tpa_classes(self): + # Deferred import — TPA models live in openedx-platform. + from common.djangoapps.third_party_auth.models import SAMLProviderConfig # pylint: disable=import-outside-toplevel + from common.djangoapps.third_party_auth.samlproviderconfig.serializers import SAMLProviderConfigSerializer # pylint: disable=import-outside-toplevel + from common.djangoapps.third_party_auth.utils import convert_saml_slug_provider_id, validate_uuid4_string # pylint: disable=import-outside-toplevel + return SAMLProviderConfig, SAMLProviderConfigSerializer, convert_saml_slug_provider_id, validate_uuid4_string + + def get_serializer_class(self): + _, SAMLProviderConfigSerializer, _, _ = self._get_tpa_classes() + return SAMLProviderConfigSerializer + + def get_queryset(self): + SAMLProviderConfig, _, _, _ = self._get_tpa_classes() + if self.requested_enterprise_uuid is None: + raise ParseError('Required enterprise_customer_uuid is missing') + enterprise_customer_idps = get_list_or_404( + EnterpriseCustomerIdentityProvider, + enterprise_customer__uuid=self.requested_enterprise_uuid + ) + slug_list = [idp.provider_id for idp in enterprise_customer_idps] + saml_config_ids = [ + config.id for config in SAMLProviderConfig.objects.current_set() + if config.provider_id in slug_list + ] + return SAMLProviderConfig.objects.filter(id__in=saml_config_ids) + + def destroy(self, request, *args, **kwargs): + SAMLProviderConfig, _, _, _ = self._get_tpa_classes() + saml_provider_config = self.get_object() + config_id = saml_provider_config.id + provider_config_provider_id = saml_provider_config.provider_id + customer_uuid = self.requested_enterprise_uuid + try: + enterprise_customer = EnterpriseCustomer.objects.get(pk=customer_uuid) + except EnterpriseCustomer.DoesNotExist: + raise ValidationError(f'Enterprise customer not found at uuid: {customer_uuid}') # pylint: disable=raise-missing-from + EnterpriseCustomerIdentityProvider.objects.filter( + enterprise_customer=enterprise_customer, + provider_id=provider_config_provider_id, + ).delete() + SAMLProviderConfig.objects.filter(id=config_id).update(archived=True, enabled=False) + return Response(status=status.HTTP_200_OK, data={'id': config_id}) + + def create(self, request, *args, **kwargs): + SAMLProviderConfig, _, convert_saml_slug_provider_id, validate_uuid4_string = self._get_tpa_classes() + enterprise_customer_uuid = request.data.get('enterprise_customer_uuid') + if not enterprise_customer_uuid or not validate_uuid4_string(enterprise_customer_uuid): + raise ParseError('enterprise_customer_uuid is missing or invalid') + try: + enterprise_customer = EnterpriseCustomer.objects.get(pk=enterprise_customer_uuid) + except EnterpriseCustomer.DoesNotExist: + raise ValidationError(f'Enterprise customer not found at uuid: {enterprise_customer_uuid}') # pylint: disable=raise-missing-from + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + try: + instance = serializer.save() + except IntegrityError: + raise ValidationError('SAML provider config with this entity_id already exists.') # pylint: disable=raise-missing-from + EnterpriseCustomerIdentityProvider.objects.get_or_create( + enterprise_customer=enterprise_customer, + provider_id=convert_saml_slug_provider_id(instance.slug), + ) + headers = self.get_success_headers(serializer.data) + return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) + + @property + def requested_enterprise_uuid(self): + return ( + self.request.query_params.get('enterprise-id') or + self.request.data.get('enterprise_customer_uuid') + ) + + def get_permission_object(self): + return self.requested_enterprise_uuid diff --git a/enterprise/api/v1/views/saml_provider_data.py b/enterprise/api/v1/views/saml_provider_data.py new file mode 100644 index 000000000..61127362a --- /dev/null +++ b/enterprise/api/v1/views/saml_provider_data.py @@ -0,0 +1,120 @@ +""" +Viewset for enterprise SAML provider data administration. +""" +import logging + +from django.http import Http404 +from django.shortcuts import get_object_or_404 +from edx_rbac.mixins import PermissionRequiredMixin +from requests.exceptions import HTTPError, MissingSchema, SSLError +from rest_framework import permissions, status, viewsets +from rest_framework.decorators import action +from rest_framework.exceptions import ParseError +from rest_framework.response import Response + +from enterprise.models import EnterpriseCustomerIdentityProvider + +log = logging.getLogger(__name__) + + +class SAMLProviderDataViewSet(PermissionRequiredMixin, viewsets.ModelViewSet): + """ + A View to handle SAMLProviderData CRUD for enterprise admin users. + + Usage:: + + GET /enterprise/api/v1/auth/saml/v0/provider_data/?enterprise-id= + POST /enterprise/api/v1/auth/saml/v0/provider_data/ + PATCH /enterprise/api/v1/auth/saml/v0/provider_data// + DELETE /enterprise/api/v1/auth/saml/v0/provider_data// + POST /enterprise/api/v1/auth/saml/v0/provider_data/sync_provider_data + """ + + permission_classes = [permissions.IsAuthenticated] + permission_required = 'enterprise.can_access_admin_dashboard' + + def _get_tpa_classes(self): + # Deferred import — TPA models live in openedx-platform. + from common.djangoapps.third_party_auth.models import SAMLProviderConfig, SAMLProviderData # pylint: disable=import-outside-toplevel + from common.djangoapps.third_party_auth.samlproviderdata.serializers import SAMLProviderDataSerializer # pylint: disable=import-outside-toplevel + from common.djangoapps.third_party_auth.utils import ( # pylint: disable=import-outside-toplevel + convert_saml_slug_provider_id, + create_or_update_bulk_saml_provider_data, + fetch_metadata_xml, + parse_metadata_xml, + validate_uuid4_string, + ) + return ( + SAMLProviderConfig, SAMLProviderData, SAMLProviderDataSerializer, + convert_saml_slug_provider_id, create_or_update_bulk_saml_provider_data, + fetch_metadata_xml, parse_metadata_xml, validate_uuid4_string, + ) + + def get_serializer_class(self): + _, _, SAMLProviderDataSerializer, *_ = self._get_tpa_classes() + return SAMLProviderDataSerializer + + def get_queryset(self): + SAMLProviderConfig, SAMLProviderData, _, convert_saml_slug_provider_id, *_ = self._get_tpa_classes() + if self.requested_enterprise_uuid is None: + raise ParseError('Required enterprise_customer_uuid is missing') + enterprise_customer_idp = get_object_or_404( + EnterpriseCustomerIdentityProvider, + enterprise_customer__uuid=self.requested_enterprise_uuid + ) + try: + saml_provider = SAMLProviderConfig.objects.current_set().get( + slug=convert_saml_slug_provider_id(enterprise_customer_idp.provider_id)) + except SAMLProviderConfig.DoesNotExist: + raise Http404('No matching SAML provider found.') # pylint: disable=raise-missing-from + provider_data_id = self.request.parser_context.get('kwargs', {}).get('pk') + if provider_data_id: + return SAMLProviderData.objects.filter(id=provider_data_id) + return SAMLProviderData.objects.filter(entity_id=saml_provider.entity_id) + + @property + def requested_enterprise_uuid(self): + return ( + self.request.query_params.get('enterprise-id') or + self.request.data.get('enterprise_customer_uuid') + ) + + def get_permission_object(self): + return self.requested_enterprise_uuid + + @action(detail=False, methods=['post'], url_path='sync_provider_data') + def sync_provider_data(self, request): + (SAMLProviderConfig, _, _, convert_saml_slug_provider_id, create_or_update_bulk_saml_provider_data, + fetch_metadata_xml, parse_metadata_xml, validate_uuid4_string) = self._get_tpa_classes() + enterprise_customer_uuid = request.data.get('enterprise_customer_uuid') + if not validate_uuid4_string(enterprise_customer_uuid): + raise ParseError('enterprise_customer_uuid is not a valid uuid4') + enterprise_customer_idp = get_object_or_404( + EnterpriseCustomerIdentityProvider, + enterprise_customer__uuid=enterprise_customer_uuid + ) + try: + saml_provider = SAMLProviderConfig.objects.current_set().get( + slug=convert_saml_slug_provider_id(enterprise_customer_idp.provider_id)) + except SAMLProviderConfig.DoesNotExist: + raise Http404('No matching SAML provider found.') # pylint: disable=raise-missing-from + metadata_url = saml_provider.metadata_source + try: + xml = fetch_metadata_xml(metadata_url) + except (SSLError, MissingSchema, HTTPError) as exc: + return Response( + data={'error': f'Failed to fetch metadata XML: {exc}'}, + status=status.HTTP_400_BAD_REQUEST, + ) + result = parse_metadata_xml(xml, saml_provider.entity_id) + if result is None: + return Response( + data={'error': 'Failed to parse metadata XML.'}, + status=status.HTTP_400_BAD_REQUEST, + ) + public_key, sso_url, expires_at = result + create_or_update_bulk_saml_provider_data(public_key, sso_url, expires_at, saml_provider.entity_id) + return Response( + data={'message': 'Synced provider data successfully.'}, + status=status.HTTP_200_OK, + )