diff --git a/.github/workflows/add-depr-ticket-to-depr-board.yml b/.github/workflows/add-depr-ticket-to-depr-board.yml deleted file mode 100644 index 250e394abc11..000000000000 --- a/.github/workflows/add-depr-ticket-to-depr-board.yml +++ /dev/null @@ -1,19 +0,0 @@ -# Run the workflow that adds new tickets that are either: -# - labelled "DEPR" -# - title starts with "[DEPR]" -# - body starts with "Proposal Date" (this is the first template field) -# to the org-wide DEPR project board - -name: Add newly created DEPR issues to the DEPR project board - -on: - issues: - types: [opened] - -jobs: - routeissue: - uses: openedx/.github/.github/workflows/add-depr-ticket-to-depr-board.yml@master - secrets: - GITHUB_APP_ID: ${{ secrets.GRAPHQL_AUTH_APP_ID }} - GITHUB_APP_PRIVATE_KEY: ${{ secrets.GRAPHQL_AUTH_APP_PEM }} - SLACK_BOT_TOKEN: ${{ secrets.SLACK_ISSUE_BOT_TOKEN }} diff --git a/.github/workflows/add-remove-label-on-comment.yml b/.github/workflows/add-remove-label-on-comment.yml deleted file mode 100644 index 0f369db7d293..000000000000 --- a/.github/workflows/add-remove-label-on-comment.yml +++ /dev/null @@ -1,20 +0,0 @@ -# This workflow runs when a comment is made on the ticket -# If the comment starts with "label: " it tries to apply -# the label indicated in rest of comment. -# If the comment starts with "remove label: ", it tries -# to remove the indicated label. -# Note: Labels are allowed to have spaces and this script does -# not parse spaces (as often a space is legitimate), so the command -# "label: really long lots of words label" will apply the -# label "really long lots of words label" - -name: Allows for the adding and removing of labels via comment - -on: - issue_comment: - types: [created] - -jobs: - add_remove_labels: - uses: openedx/.github/.github/workflows/add-remove-label-on-comment.yml@master - diff --git a/.github/workflows/js-tests.yml b/.github/workflows/js-tests.yml index 94a1368e96a5..972435a820e3 100644 --- a/.github/workflows/js-tests.yml +++ b/.github/workflows/js-tests.yml @@ -4,7 +4,7 @@ on: pull_request: push: branches: - - master + - release-ulmo jobs: run_tests: diff --git a/.github/workflows/lint-imports.yml b/.github/workflows/lint-imports.yml index baf914298be2..17cc4ea0a935 100644 --- a/.github/workflows/lint-imports.yml +++ b/.github/workflows/lint-imports.yml @@ -4,7 +4,7 @@ on: pull_request: push: branches: - - master + - release-ulmo jobs: lint-imports: diff --git a/.github/workflows/lockfileversion-check.yml b/.github/workflows/lockfileversion-check.yml index 736f1f98de13..ed0b7f97de6d 100644 --- a/.github/workflows/lockfileversion-check.yml +++ b/.github/workflows/lockfileversion-check.yml @@ -5,7 +5,7 @@ name: Lockfile Version check on: push: branches: - - master + - release-ulmo pull_request: jobs: diff --git a/.github/workflows/migrations-check.yml b/.github/workflows/migrations-check.yml index cd4d09589c12..686d8d9086b0 100644 --- a/.github/workflows/migrations-check.yml +++ b/.github/workflows/migrations-check.yml @@ -5,7 +5,7 @@ on: pull_request: push: branches: - - master + - release-ulmo jobs: check_migrations: diff --git a/.github/workflows/quality-checks.yml b/.github/workflows/quality-checks.yml index 3f4cbeeb4df9..ee67e3903569 100644 --- a/.github/workflows/quality-checks.yml +++ b/.github/workflows/quality-checks.yml @@ -4,8 +4,7 @@ on: pull_request: push: branches: - - master - - open-release/lilac.master + - release-ulmo jobs: run_tests: diff --git a/.github/workflows/semgrep.yml b/.github/workflows/semgrep.yml index 520cd23a678b..d9d32ab9d36d 100644 --- a/.github/workflows/semgrep.yml +++ b/.github/workflows/semgrep.yml @@ -9,7 +9,7 @@ on: pull_request: push: branches: - - master + - release-ulmo jobs: run_semgrep: diff --git a/.github/workflows/shellcheck.yml b/.github/workflows/shellcheck.yml index 2e5b04bcc2ff..a8df63a20d57 100644 --- a/.github/workflows/shellcheck.yml +++ b/.github/workflows/shellcheck.yml @@ -9,7 +9,7 @@ on: pull_request: push: branches: - - master + - release-ulmo permissions: contents: read diff --git a/.github/workflows/static-assets-check.yml b/.github/workflows/static-assets-check.yml index 43cb597c16d7..3a0afa76deb7 100644 --- a/.github/workflows/static-assets-check.yml +++ b/.github/workflows/static-assets-check.yml @@ -4,7 +4,7 @@ on: pull_request: push: branches: - - master + - release-ulmo jobs: static_assets_check: diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index d8b8c26cd049..036d411fa1a0 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -4,7 +4,7 @@ on: pull_request: push: branches: - - master + - release-ulmo concurrency: # We only need to be running tests for the latest commit on each PR @@ -74,6 +74,12 @@ jobs: run: | sudo apt-get update && sudo apt-get install libmysqlclient-dev libxmlsec1-dev lynx + - name: Upgrade Docker + run: | + sudo apt-get update + sudo apt-get install --only-upgrade docker-ce docker-ce-cli containerd.io + docker --version + # We pull this image a lot, and Dockerhub will rate limit us if we pull too often. # This is an attempt to cache the image for better performance and to work around that. # It will cache all pulled images, so if we add new images to this we'll need to update the key. @@ -83,9 +89,10 @@ jobs: key: docker-${{ runner.os }}-mongo-${{ matrix.mongo-version }} - name: Start MongoDB - uses: supercharge/mongodb-github-action@1.12.0 - with: - mongodb-version: ${{ matrix.mongo-version }} + run: | + docker run -d -p 27017:27017 --name mongodb mongo:${{ matrix.mongo-version }} + sleep 10 + docker ps - name: Setup Python uses: actions/setup-python@v5 @@ -124,26 +131,26 @@ jobs: shell: bash run: | cd test_root/log - mv pytest_warnings.json pytest_warnings_${{ matrix.shard_name }}.json + mv pytest_warnings.json pytest_warnings_${{ matrix.shard_name }}_${{ matrix.python-version }}_${{ matrix.django-version }}_${{ matrix.mongo-version }}_${{ matrix.os-version }}.json - name: save pytest warnings json file if: success() uses: actions/upload-artifact@v4 with: - name: pytest-warnings-json-${{ matrix.shard_name }} + name: pytest-warnings-json-${{ matrix.shard_name }}-${{ matrix.python-version }}-${{ matrix.django-version }}-${{ matrix.mongo-version }}-${{ matrix.os-version }} path: | test_root/log/pytest_warnings*.json overwrite: true - name: Renaming coverage data file run: | - mv reports/.coverage reports/${{ matrix.shard_name }}.coverage + mv reports/.coverage reports/${{ matrix.shard_name }}_${{ matrix.python-version }}_${{ matrix.django-version }}_${{ matrix.mongo-version }}_${{ matrix.os-version }}.coverage - name: Upload coverage uses: actions/upload-artifact@v4 with: - name: coverage-${{ matrix.shard_name }} - path: reports/${{ matrix.shard_name }}.coverage + name: coverage-${{ matrix.shard_name }}-${{ matrix.python-version }}-${{ matrix.django-version }}-${{ matrix.mongo-version }}-${{ matrix.os-version }} + path: reports/${{ matrix.shard_name }}_${{ matrix.python-version }}_${{ matrix.django-version }}_${{ matrix.mongo-version }}_${{ matrix.os-version }}.coverage overwrite: true collect-and-verify: diff --git a/.github/workflows/units-test-scripts-structures-pruning.yml b/.github/workflows/units-test-scripts-structures-pruning.yml index 14a01b592308..ef408cfe66ec 100644 --- a/.github/workflows/units-test-scripts-structures-pruning.yml +++ b/.github/workflows/units-test-scripts-structures-pruning.yml @@ -4,7 +4,7 @@ on: pull_request: push: branches: - - master + - release-ulmo jobs: test: diff --git a/.github/workflows/units-test-scripts-user-retirement.yml b/.github/workflows/units-test-scripts-user-retirement.yml index 889c43a64a48..b43bbf46b0d4 100644 --- a/.github/workflows/units-test-scripts-user-retirement.yml +++ b/.github/workflows/units-test-scripts-user-retirement.yml @@ -4,7 +4,7 @@ on: pull_request: push: branches: - - master + - release-ulmo jobs: test: diff --git a/.github/workflows/update-geolite-database.yml b/.github/workflows/update-geolite-database.yml index 484fa167a371..d9ad767ce57e 100644 --- a/.github/workflows/update-geolite-database.yml +++ b/.github/workflows/update-geolite-database.yml @@ -8,7 +8,7 @@ on: branch: description: "Target branch against which to create PR" required: false - default: "master" + default: "release-ulmo" env: MAXMIND_URL: "https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-Country&license_key=${{ secrets.MAXMIND_LICENSE_KEY }}&suffix=tar.gz" @@ -69,7 +69,7 @@ jobs: - name: Create a branch, commit the code and make a PR id: create-pr run: | - BRANCH="${{ github.actor }}/geoip2-bot-update-country-database-$(echo "${{ github.sha }}" | cut -c 1-7)" + BRANCH="${{ github.actor }}/geoip2-bot-update-country-database-${{ github.run_id }}" git checkout -b $BRANCH git add . git status @@ -79,13 +79,13 @@ jobs: --title "Update GeoLite Database" \ --body "PR generated by workflow `${{ github.workflow }}` on behalf of @${{ github.actor }}." \ --head $BRANCH \ - --base 'master' \ - --reviewer 'feanil' \ + --base 'release-ulmo' \ + --reviewer 'edx/orbi-bom' \ | grep -o 'https://github.com/.*/pull/[0-9]*') echo "PR Created: ${PR_URL}" echo "pull-request-url=$PR_URL" >> $GITHUB_OUTPUT env: - GH_TOKEN: ${{ github.token }} + GH_TOKEN: ${{ secrets.GH_PAT_WITH_ORG }} - name: Job summary run: | diff --git a/.github/workflows/upgrade-one-python-dependency.yml b/.github/workflows/upgrade-one-python-dependency.yml index 3f9678593c25..1d8170961865 100644 --- a/.github/workflows/upgrade-one-python-dependency.yml +++ b/.github/workflows/upgrade-one-python-dependency.yml @@ -6,7 +6,7 @@ on: branch: description: "Target branch to create requirements PR against" required: true - default: "master" + default: "release-ulmo" type: string package: description: "Name of package to upgrade" diff --git a/.github/workflows/upgrade-python-requirements.yml b/.github/workflows/upgrade-python-requirements.yml index cbb70b06b79d..90c6be00744c 100644 --- a/.github/workflows/upgrade-python-requirements.yml +++ b/.github/workflows/upgrade-python-requirements.yml @@ -8,7 +8,7 @@ on: branch: description: "Target branch to create requirements PR against" required: true - default: "master" + default: "release-ulmo" jobs: call-upgrade-python-requirements-workflow: # Don't run the weekly upgrade job on forks -- it will send a weekly failure email. diff --git a/.github/workflows/verify-dunder-init.yml b/.github/workflows/verify-dunder-init.yml index c3248def2f33..462f0e06af0b 100644 --- a/.github/workflows/verify-dunder-init.yml +++ b/.github/workflows/verify-dunder-init.yml @@ -3,7 +3,7 @@ name: Verify Dunder __init__.py Files on: pull_request: branches: - - master + - release-ulmo jobs: verify_dunder_init: diff --git a/.gitignore b/.gitignore index a5d5252de705..5587df58faac 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,10 @@ requirements/edx/private.in requirements/edx/private.txt lms/envs/private.py cms/envs/private.py +.venv/ +CLAUDE.md +.claude/ +AGENTS.md # end-noclean ### Python artifacts diff --git a/Makefile b/Makefile index 6c525a57b67e..66c3608af038 100644 --- a/Makefile +++ b/Makefile @@ -58,7 +58,8 @@ pull_translations: clean_translations ## pull translations via atlas make pull_plugin_translations atlas pull $(ATLAS_OPTIONS) \ translations/edx-platform/conf/locale:conf/locale \ - translations/studio-frontend/src/i18n/messages:conf/plugins-locale/studio-frontend + translations/studio-frontend/src/i18n/messages:conf/plugins-locale/studio-frontend \ + $(ATLAS_EXTRA_SOURCES) python manage.py lms compilemessages python manage.py lms compilejsi18n python manage.py cms compilejsi18n diff --git a/README.rst b/README.rst index dcd6e32c4998..a67e281f041b 100644 --- a/README.rst +++ b/README.rst @@ -74,7 +74,7 @@ OS: * Ubuntu 24.04 -Interperters/Tools: +Interpreters/Tools: * Python 3.11 diff --git a/cms/djangoapps/contentstore/exams.py b/cms/djangoapps/contentstore/exams.py index 6b25147c6abc..8a4ddc09425e 100644 --- a/cms/djangoapps/contentstore/exams.py +++ b/cms/djangoapps/contentstore/exams.py @@ -74,13 +74,13 @@ def register_exams(course_key): # Exams in courses not using an LTI based proctoring provider should use the original definition of due_date # from contentstore/proctoring.py. These exams are powered by the edx-proctoring plugin and not the edx-exams # microservice. + is_instructor_paced = not course.self_paced if course.proctoring_provider == 'lti_external': - due_date = ( - timed_exam.due.isoformat() if timed_exam.due - else (course.end.isoformat() if course.end else None) - ) + due_date_source = timed_exam.due if is_instructor_paced else course.end else: - due_date = timed_exam.due if not course.self_paced else None + due_date_source = timed_exam.due if is_instructor_paced else None + + due_date = due_date_source.isoformat() if due_date_source else None exams_list.append({ 'course_id': str(course_key), diff --git a/cms/djangoapps/contentstore/helpers.py b/cms/djangoapps/contentstore/helpers.py index 2bdbabc7df8c..91236a4dade9 100644 --- a/cms/djangoapps/contentstore/helpers.py +++ b/cms/djangoapps/contentstore/helpers.py @@ -745,10 +745,10 @@ def _import_file_into_course( if thumbnail_content is not None: content.thumbnail_location = thumbnail_location contentstore().save(content) - return True, {clipboard_file_path: f"static/{import_path}"} + return True, {clipboard_file_path: filename if not import_path else f"static/{import_path}"} elif current_file.content_digest == file_data_obj.md5_hash: - # The file already exists and matches exactly, so no action is needed except substitutions - return None, {clipboard_file_path: f"static/{import_path}"} + # The file already exists and matches exactly, so no action is needed + return None, {} else: # There is a conflict with some other file that has the same name. return False, {} diff --git a/cms/djangoapps/contentstore/rest_api/v1/serializers/course_waffle_flags.py b/cms/djangoapps/contentstore/rest_api/v1/serializers/course_waffle_flags.py index 3efb7b6226d4..fde90163f803 100644 --- a/cms/djangoapps/contentstore/rest_api/v1/serializers/course_waffle_flags.py +++ b/cms/djangoapps/contentstore/rest_api/v1/serializers/course_waffle_flags.py @@ -11,6 +11,7 @@ class CourseWaffleFlagsSerializer(serializers.Serializer): """ Serializer for course waffle flags """ + use_new_home_page = serializers.SerializerMethodField() use_new_custom_pages = serializers.SerializerMethodField() use_new_schedule_details_page = serializers.SerializerMethodField() @@ -31,6 +32,7 @@ class CourseWaffleFlagsSerializer(serializers.Serializer): use_react_markdown_editor = serializers.SerializerMethodField() use_video_gallery_flow = serializers.SerializerMethodField() enable_course_optimizer_check_prev_run_links = serializers.SerializerMethodField() + enable_unit_expanded_view = serializers.SerializerMethodField() def get_course_key(self): """ @@ -175,3 +177,10 @@ def get_enable_course_optimizer_check_prev_run_links(self, obj): """ course_key = self.get_course_key() return toggles.enable_course_optimizer_check_prev_run_links(course_key) + + def get_enable_unit_expanded_view(self, obj): + """ + Method to get the enable_unit_expanded_view waffle flag + """ + course_key = self.get_course_key() + return toggles.enable_unit_expanded_view(course_key) diff --git a/cms/djangoapps/contentstore/rest_api/v1/urls.py b/cms/djangoapps/contentstore/rest_api/v1/urls.py index 685a81d778ce..8a94f0b0e040 100644 --- a/cms/djangoapps/contentstore/rest_api/v1/urls.py +++ b/cms/djangoapps/contentstore/rest_api/v1/urls.py @@ -25,6 +25,7 @@ HomePageView, ProctoredExamSettingsView, ProctoringErrorsView, + UnitComponentsView, VideoDownloadView, VideoUsageView, vertical_container_children_redirect_view, @@ -144,6 +145,11 @@ CourseWaffleFlagsView.as_view(), name="course_waffle_flags" ), + re_path( + fr'^unit_handler/{settings.USAGE_KEY_PATTERN}$', + UnitComponentsView.as_view(), + name="unit_components" + ), # Authoring API # Do not use under v1 yet (Nov. 23). The Authoring API is still experimental and the v0 versions should be used diff --git a/cms/djangoapps/contentstore/rest_api/v1/views/__init__.py b/cms/djangoapps/contentstore/rest_api/v1/views/__init__.py index d4fcfd5f2e3f..25c0157904e9 100644 --- a/cms/djangoapps/contentstore/rest_api/v1/views/__init__.py +++ b/cms/djangoapps/contentstore/rest_api/v1/views/__init__.py @@ -14,9 +14,6 @@ from .proctoring import ProctoredExamSettingsView, ProctoringErrorsView from .settings import CourseSettingsView from .textbooks import CourseTextbooksView +from .unit_handler import UnitComponentsView from .vertical_block import ContainerHandlerView, vertical_container_children_redirect_view -from .videos import ( - CourseVideosView, - VideoDownloadView, - VideoUsageView, -) +from .videos import CourseVideosView, VideoDownloadView, VideoUsageView diff --git a/cms/djangoapps/contentstore/rest_api/v1/views/tests/test_course_waffle_flags.py b/cms/djangoapps/contentstore/rest_api/v1/views/tests/test_course_waffle_flags.py index f45cc48810d6..a788ce4af3b5 100644 --- a/cms/djangoapps/contentstore/rest_api/v1/views/tests/test_course_waffle_flags.py +++ b/cms/djangoapps/contentstore/rest_api/v1/views/tests/test_course_waffle_flags.py @@ -38,6 +38,7 @@ class CourseWaffleFlagsViewTest(CourseTestCase): "use_react_markdown_editor": False, "use_video_gallery_flow": False, "enable_course_optimizer_check_prev_run_links": False, + "enable_unit_expanded_view": False, } def setUp(self): diff --git a/cms/djangoapps/contentstore/rest_api/v1/views/unit_handler.py b/cms/djangoapps/contentstore/rest_api/v1/views/unit_handler.py new file mode 100644 index 000000000000..0740152e4fab --- /dev/null +++ b/cms/djangoapps/contentstore/rest_api/v1/views/unit_handler.py @@ -0,0 +1,129 @@ +"""API Views for unit components handler""" + +import logging + +import edx_api_doc_tools as apidocs +from django.http import HttpResponseBadRequest +from opaque_keys.edx.keys import UsageKey +from rest_framework.request import Request +from rest_framework.response import Response +from rest_framework.views import APIView + +from cms.djangoapps.contentstore.rest_api.v1.mixins import ContainerHandlerMixin +from openedx.core.lib.api.view_utils import view_auth_classes +from xmodule.modulestore.django import modulestore +from xmodule.modulestore.exceptions import ItemNotFoundError + +log = logging.getLogger(__name__) + + +@view_auth_classes(is_authenticated=True) +class UnitComponentsView(APIView, ContainerHandlerMixin): + """ + View to get all components in a unit by usage key. + """ + + @apidocs.schema( + parameters=[ + apidocs.string_parameter( + "usage_key_string", + apidocs.ParameterLocation.PATH, + description="Unit usage key", + ), + ], + responses={ + 200: "List of components in the unit", + 400: "Invalid usage key or unit not found.", + 401: "The requester is not authenticated.", + 404: "The requested unit does not exist.", + }, + ) + def get(self, request: Request, usage_key_string: str): + """ + Get all components in a unit. + + **Example Request** + + GET /api/contentstore/v1/unit_handler/{usage_key_string} + + **Response Values** + + If the request is successful, an HTTP 200 "OK" response is returned. + + The HTTP 200 response contains a dict with a list of all components + in the unit, including their display names, block types, and block IDs. + + **Example Response** + + ```json + { + "unit_id": "block-v1:edX+DemoX+Demo_Course+type@vertical+block@vertical_id", + "display_name": "My Unit", + "components": [ + { + "block_id": "block-v1:edX+DemoX+Demo_Course+type@video+block@video_id", + "block_type": "video", + "display_name": "Introduction Video" + }, + { + "block_id": "block-v1:edX+DemoX+Demo_Course+type@html+block@html_id", + "block_type": "html", + "display_name": "Text Content" + }, + { + "block_id": "block-v1:edX+DemoX+Demo_Course+type@problem+block@problem_id", + "block_type": "problem", + "display_name": "Practice Problem" + } + ] + } + ``` + """ + try: + usage_key = UsageKey.from_string(usage_key_string) + except Exception as e: # pylint: disable=broad-exception-caught + log.error(f"Invalid usage key: {usage_key_string}, error: {str(e)}") + return HttpResponseBadRequest("Invalid usage key format") + + try: + # Get the unit xblock + unit_xblock = modulestore().get_item(usage_key) + + # Verify it's a vertical (unit) + if unit_xblock.category != "vertical": + return HttpResponseBadRequest( + "The provided usage key is not a unit (vertical)" + ) + + components = [] + + # Get all children (components) of the unit + if unit_xblock.has_children: + for child_usage_key in unit_xblock.children: + try: + child_xblock = modulestore().get_item(child_usage_key) + components.append( + { + "block_id": str(child_xblock.location), + "block_type": child_xblock.category, + "display_name": child_xblock.display_name_with_default, + } + ) + except ItemNotFoundError: + log.warning(f"Child block not found: {child_usage_key}") + continue + + response_data = { + "unit_id": str(usage_key), + "display_name": unit_xblock.display_name_with_default, + "components": components, + } + + return Response(response_data) + + except ItemNotFoundError: + log.error(f"Unit not found: {usage_key_string}") + return HttpResponseBadRequest("Unit not found") + except Exception as e: # pylint: disable=broad-exception-caught + log.error(f"Error retrieving unit components: {str(e)}") + return HttpResponseBadRequest(f"Error retrieving unit components: {str(e)}") diff --git a/cms/djangoapps/contentstore/tests/test_course_listing.py b/cms/djangoapps/contentstore/tests/test_course_listing.py index e46b493b7b39..a2b6f07d15ef 100644 --- a/cms/djangoapps/contentstore/tests/test_course_listing.py +++ b/cms/djangoapps/contentstore/tests/test_course_listing.py @@ -24,8 +24,10 @@ get_courses_accessible_to_user ) from common.djangoapps.course_action_state.models import CourseRerunState +from common.djangoapps.student.models.user import CourseAccessRole from common.djangoapps.student.roles import ( CourseInstructorRole, + CourseLimitedStaffRole, CourseStaffRole, GlobalStaff, OrgInstructorRole, @@ -188,6 +190,48 @@ def test_staff_course_listing(self): with self.assertNumQueries(2): list(_accessible_courses_summary_iter(self.request)) + def test_course_limited_staff_course_listing(self): + # Setup a new course + course_location = self.store.make_course_key('Org', 'CreatedCourse', 'Run') + CourseFactory.create( + org=course_location.org, + number=course_location.course, + run=course_location.run + ) + course = CourseOverviewFactory.create(id=course_location, org=course_location.org) + + # Add the user as a course_limited_staff on the course + CourseLimitedStaffRole(course.id).add_users(self.user) + self.assertTrue(CourseLimitedStaffRole(course.id).has_user(self.user)) + + # Fetch accessible courses list & verify their count + courses_list_by_staff, __ = get_courses_accessible_to_user(self.request) + + # Limited Course Staff should not be able to list courses in Studio + assert len(list(courses_list_by_staff)) == 0 + + def test_org_limited_staff_course_listing(self): + + # Setup a new course + course_location = self.store.make_course_key('Org', 'CreatedCourse', 'Run') + CourseFactory.create( + org=course_location.org, + number=course_location.course, + run=course_location.run + ) + course = CourseOverviewFactory.create(id=course_location, org=course_location.org) + + # Add a user as course_limited_staff on the org + # This is not possible using the course roles classes but is possible via Django admin so we + # insert a row into the model directly to test that scenario. + CourseAccessRole.objects.create(user=self.user, org=course_location.org, role=CourseLimitedStaffRole.ROLE) + + # Fetch accessible courses list & verify their count + courses_list_by_staff, __ = get_courses_accessible_to_user(self.request) + + # Limited Course Staff should not be able to list courses in Studio + assert len(list(courses_list_by_staff)) == 0 + def test_get_course_list_with_invalid_course_location(self): """ Test getting courses with invalid course location (course deleted from modulestore). diff --git a/cms/djangoapps/contentstore/tests/test_exams.py b/cms/djangoapps/contentstore/tests/test_exams.py index 798b5e51fd80..823038957714 100644 --- a/cms/djangoapps/contentstore/tests/test_exams.py +++ b/cms/djangoapps/contentstore/tests/test_exams.py @@ -65,16 +65,21 @@ def _get_exam_due_date(self, course, sequential): Return the expected exam due date for the exam, based on the selected course proctoring provider and the exam due date or the course end date. + This is a copy of the due date computation logic in register_exams function. + Arguments: * course: the course that the exam subsection is in; may have a course.end attribute * sequential: the exam subsection; may have a sequential.due attribute """ + is_instructor_paced = not course.self_paced if course.proctoring_provider == 'lti_external': - return sequential.due.isoformat() if sequential.due else (course.end.isoformat() if course.end else None) - elif course.self_paced: - return None + due_date_source = sequential.due if is_instructor_paced else course.end else: - return sequential.due + due_date_source = sequential.due if is_instructor_paced else None + + due_date = due_date_source.isoformat() if due_date_source else None + + return due_date @ddt.data(*(tuple(base) + (extra,) for base, extra in itertools.product( [ @@ -185,14 +190,13 @@ def test_feature_flag_off(self, mock_patch_course_exams): def test_no_due_dates(self, is_self_paced, course_end_date, proctoring_provider, mock_patch_course_exams): """ Test that the the correct due date is registered for the exam when the subsection does not have a due date, - depending on the proctoring provider. + depending on the proctoring provider and course pacing type. * lti_external - * The course end date is registered as the due date when the subsection does not have a due date for both - self-paced and instructor-paced exams. + * If the course is instructor-paced, the exam due date is the subsection due date if it exists, else None. + * If the course is self-paced, the exam due date is the course end date if it exists, else None. * not lti_external - * None is registered as the due date when the subsection does not have a due date for both - self-paced and instructor-paced exams. + * The exam due date is always the subsection due date if it exists, else None. """ self.course.self_paced = is_self_paced self.course.end = course_end_date @@ -222,25 +226,17 @@ def test_no_due_dates(self, is_self_paced, course_end_date, proctoring_provider, @ddt.data(*itertools.product((True, False), ('lti_external', 'null'))) @ddt.unpack @freeze_time('2024-01-01') - def test_subsection_due_date_prioritized(self, is_self_paced, proctoring_provider, mock_patch_course_exams): + def test_subsection_due_date_prioritized_instructor_paced( + self, + is_self_paced, + proctoring_provider, + mock_patch_course_exams + ): """ - Test that the subsection due date is registered as the due date when both the subsection has a due date and the - course has an end date for both self-paced and instructor-paced exams. - - Test that the the correct due date is registered for the exam when the subsection has a due date, depending on - the proctoring provider. - - * lti_external - * The subsection due date is registered as the due date when both the subsection has a due date and the - course has an end date for both self-paced and instructor-paced exams - * not lti_external - * None is registered as the due date when both the subsection has a due date and the course has an end date - for self-paced exams. - * The subsection due date is registered as the due date when both the subsection has a due date and the - course has an end date for instructor-paced exams. + Test that exam due date is computed correctly. """ self.course.self_paced = is_self_paced - self.course.end = datetime(2035, 1, 1, 0, 0) + self.course.end = datetime(2035, 1, 1, 0, 0, tzinfo=timezone.utc) self.course.proctoring_provider = proctoring_provider self.course = self.update_course(self.course, 1) @@ -260,7 +256,7 @@ def test_subsection_due_date_prioritized(self, is_self_paced, proctoring_provide ) listen_for_course_publish(self, self.course.id) - called_exams, called_course = mock_patch_course_exams.call_args[0] + called_exams, _ = mock_patch_course_exams.call_args[0] expected_due_date = self._get_exam_due_date(self.course, sequence) diff --git a/cms/djangoapps/contentstore/toggles.py b/cms/djangoapps/contentstore/toggles.py index c287f8c4dbec..c3dba4a6f4a4 100644 --- a/cms/djangoapps/contentstore/toggles.py +++ b/cms/djangoapps/contentstore/toggles.py @@ -682,3 +682,23 @@ def enable_course_optimizer_check_prev_run_links(course_key): Returns a boolean if previous run course optimizer feature is enabled for the given course. """ return ENABLE_COURSE_OPTIMIZER_CHECK_PREV_RUN_LINKS.is_enabled(course_key) + + +# .. toggle_name: contentstore.enable_unit_expanded_view +# .. toggle_implementation: CourseWaffleFlag +# .. toggle_default: False +# .. toggle_description: When enabled, the Unit Expanded View feature in the Course Outline is activated. +# .. toggle_use_cases: temporary +# .. toggle_creation_date: 2026-01-01 +# .. toggle_target_removal_date: 2026-06-01 +# .. toggle_tickets: TNL2-473 +ENABLE_UNIT_EXPANDED_VIEW = CourseWaffleFlag( + f"{CONTENTSTORE_NAMESPACE}.enable_unit_expanded_view", __name__ +) + + +def enable_unit_expanded_view(course_key): + """ + Returns a boolean if the Unit Expanded View feature is enabled for the given course. + """ + return ENABLE_UNIT_EXPANDED_VIEW.is_enabled(course_key) diff --git a/cms/djangoapps/contentstore/video_storage_handlers.py b/cms/djangoapps/contentstore/video_storage_handlers.py index 87086c9951ac..c39970d56d20 100644 --- a/cms/djangoapps/contentstore/video_storage_handlers.py +++ b/cms/djangoapps/contentstore/video_storage_handlers.py @@ -13,12 +13,11 @@ import shutil import pathlib import zipfile - from contextlib import closing from datetime import datetime, timedelta from uuid import uuid4 -from boto.s3.connection import S3Connection -from boto import s3 + +import boto3 from django.conf import settings from django.contrib.staticfiles.storage import staticfiles_storage from django.http import FileResponse, HttpResponseNotFound, StreamingHttpResponse @@ -55,10 +54,7 @@ from common.djangoapps.util.json_request import JsonResponse from openedx.core.djangoapps.video_config.models import VideoTranscriptEnabledFlag from openedx.core.djangoapps.video_config.toggles import PUBLIC_VIDEO_SHARE -from openedx.core.djangoapps.video_pipeline.config.waffle import ( - DEPRECATE_YOUTUBE, - ENABLE_DEVSTACK_VIDEO_UPLOADS, -) +from openedx.core.djangoapps.video_pipeline.config.waffle import DEPRECATE_YOUTUBE from openedx.core.djangoapps.waffle_utils import CourseWaffleFlag from xmodule.modulestore.django import modulestore # lint-amnesty, pylint: disable=wrong-import-order @@ -812,7 +808,8 @@ def videos_post(course, request): if error: return {'error': error}, 400 - bucket = storage_service_bucket() + s3_client = boto3.client('s3') + req_files = data['files'] resp_files = [] @@ -826,7 +823,6 @@ def videos_post(course, request): return {'error': error_msg}, 400 edx_video_id = str(uuid4()) - key = storage_service_key(bucket, file_name=edx_video_id) metadata_list = [ ('client_video_id', file_name), @@ -846,12 +842,15 @@ def videos_post(course, request): if transcript_preferences is not None: metadata_list.append(('transcript_preferences', json.dumps(transcript_preferences))) - for metadata_name, value in metadata_list: - key.set_metadata(metadata_name, value) - upload_url = key.generate_url( - KEY_EXPIRATION_IN_SECONDS, - 'PUT', - headers={'Content-Type': req_file['content_type']} + upload_url = s3_client.generate_presigned_url( + ClientMethod='put_object', + Params={ + 'Bucket': storage_service_bucket_name(), + 'Key': storage_service_key_name(edx_video_id), + 'ContentType': req_file['content_type'], + 'Metadata': dict(metadata_list), + }, + ExpiresIn=KEY_EXPIRATION_IN_SECONDS, ) # persist edx_video_id in VAL @@ -869,41 +868,21 @@ def videos_post(course, request): return {'files': resp_files}, 200 -def storage_service_bucket(): +def storage_service_bucket_name(): """ - Returns an S3 bucket for video upload. + Returns name of S3 bucket to use for video upload. """ - if ENABLE_DEVSTACK_VIDEO_UPLOADS.is_enabled(): - params = { - 'aws_access_key_id': settings.AWS_ACCESS_KEY_ID, - 'aws_secret_access_key': settings.AWS_SECRET_ACCESS_KEY, - 'security_token': settings.AWS_SECURITY_TOKEN - - } - else: - params = { - 'aws_access_key_id': settings.AWS_ACCESS_KEY_ID, - 'aws_secret_access_key': settings.AWS_SECRET_ACCESS_KEY - } - - conn = S3Connection(**params) - - # We don't need to validate our bucket, it requires a very permissive IAM permission - # set since behind the scenes it fires a HEAD request that is equivalent to get_all_keys() - # meaning it would need ListObjects on the whole bucket, not just the path used in each - # environment (since we share a single bucket for multiple deployments in some configurations) - return conn.get_bucket(settings.VIDEO_UPLOAD_PIPELINE['VEM_S3_BUCKET'], validate=False) + return settings.VIDEO_UPLOAD_PIPELINE['VEM_S3_BUCKET'] -def storage_service_key(bucket, file_name): +def storage_service_key_name(file_name): """ - Returns an S3 key to the given file in the given bucket. + Returns the S3 object key to be used for a given video filename. """ - key_name = "{}/{}".format( + return "{}/{}".format( settings.VIDEO_UPLOAD_PIPELINE.get("ROOT_PATH", ""), file_name ) - return s3.key.Key(bucket, key_name) def send_video_status_update(updates): diff --git a/cms/djangoapps/contentstore/views/block.py b/cms/djangoapps/contentstore/views/block.py index b57042085df2..238627bc6618 100644 --- a/cms/djangoapps/contentstore/views/block.py +++ b/cms/djangoapps/contentstore/views/block.py @@ -1,8 +1,10 @@ """Views for blocks.""" import logging +import re from collections import OrderedDict from functools import partial +from urllib.parse import urlparse from django.contrib.auth.decorators import login_required from django.core.exceptions import PermissionDenied @@ -305,6 +307,43 @@ def xblock_view_handler(request, usage_key_string, view_name): return HttpResponse(status=406) +def _get_safe_return_to(request): + """ + Read and validate the ``returnTo`` query parameter for the XBlock edit view. + + Returns the parameter value if it is a safe same-origin URL (i.e. an + absolute-path reference that starts with ``/`` but not ``//``), or ``None`` + if the parameter is absent or fails validation. This prevents open-redirect + attacks via protocol-relative URLs such as ``//evil.com/path``. + """ + return_to = request.GET.get('returnTo', '').strip() + if not return_to: + return None + + if re.search(r'[\x00-\x1f\x7f]', return_to): + return None + + if len(return_to) > 2048: + return None + + parsed = urlparse(return_to) + if parsed.scheme or parsed.netloc: + request_origin = '{scheme}://{host}'.format( + scheme=request.scheme, + host=request.get_host(), + ) + url_origin = '{scheme}://{host}'.format( + scheme=parsed.scheme, + host=parsed.netloc, + ) + if request_origin != url_origin: + return None + elif not return_to.startswith('/') or return_to.startswith('//'): + return None + + return return_to + + @xframe_options_exempt @require_http_methods(["GET"]) @login_required @@ -313,6 +352,10 @@ def xblock_edit_view(request, usage_key_string): Return rendered xblock edit view. Allows editing of an XBlock specified by the usage key. + + Supports an optional ``returnTo`` query parameter. When present and + pointing to a same-origin URL, the editor will redirect the browser to + that URL after the user saves or cancels instead of leaving the page blank. """ usage_key = usage_key_with_run(usage_key_string) if not has_studio_read_access(request.user, usage_key.course_key): @@ -333,6 +376,7 @@ def xblock_edit_view(request, usage_key_string): container_handler_context.update({ "action_name": "edit", "resources": list(hashed_resources.items()), + "return_to": _get_safe_return_to(request), }) return render_to_response('container_editor.html', container_handler_context) diff --git a/cms/djangoapps/contentstore/views/component.py b/cms/djangoapps/contentstore/views/component.py index 34c1f465c566..e0992bff2a39 100644 --- a/cms/djangoapps/contentstore/views/component.py +++ b/cms/djangoapps/contentstore/views/component.py @@ -43,6 +43,12 @@ from xmodule.modulestore.django import modulestore # lint-amnesty, pylint: disable=wrong-import-order from xmodule.modulestore.exceptions import ItemNotFoundError # lint-amnesty, pylint: disable=wrong-import-order +try: + from games.toggles import is_games_xblock_enabled # pylint: disable=import-error +except ImportError: + def is_games_xblock_enabled(): + return False + __all__ = [ 'container_handler', 'component_handler', @@ -83,13 +89,27 @@ ] DEFAULT_ADVANCED_MODULES = [ + 'annotatable', + 'done', + 'split_test', + 'freetextresponse', 'google-calendar', 'google-document', + 'imagemodal', + 'h5pxblock', + 'invideoquiz', 'lti_consumer', + 'oppia', + 'ubcpi', 'poll', - 'split_test', + 'qualtricssurvey', + 'scorm', + 'edx_sga', + 'submit-and-compare', 'survey', 'word_cloud', + 'recommender', + 'library_content', ] @@ -295,6 +315,11 @@ def create_support_legend_dict(): # by the components in the order listed in COMPONENT_TYPES. component_types = COMPONENT_TYPES[:] + # Add games xblock if enabled (checked at request time) + if is_games_xblock_enabled(): + component_types.append('games') + component_display_names['games'] = _("Games") + # Libraries do not support discussions, drag-and-drop, and openassessment and other libraries component_not_supported_by_library = [ 'discussion', @@ -439,12 +464,16 @@ def create_support_legend_dict(): ) categories.add(component) + beta_types = BETA_COMPONENT_TYPES[:] + if is_games_xblock_enabled() and category == 'games': + beta_types.append('games') + component_templates.append({ "type": category, "templates": templates_for_category, "display_name": component_display_names[category], "support_legend": create_support_legend_dict(), - "beta": category in BETA_COMPONENT_TYPES, + "beta": category in beta_types, }) # Libraries do not support advanced components at this time. diff --git a/cms/djangoapps/contentstore/views/course.py b/cms/djangoapps/contentstore/views/course.py index fa8769dc0cb9..a93a35cf1d3c 100644 --- a/cms/djangoapps/contentstore/views/course.py +++ b/cms/djangoapps/contentstore/views/course.py @@ -56,6 +56,7 @@ GlobalStaff, UserBasedRole, OrgStaffRole, + strict_role_checking, ) from common.djangoapps.util.json_request import JsonResponse, JsonResponseBadRequest, expect_json from common.djangoapps.util.string_utils import _has_non_ascii_characters @@ -536,7 +537,9 @@ def filter_ccx(course_access): return not isinstance(course_access.course_id, CCXLocator) instructor_courses = UserBasedRole(request.user, CourseInstructorRole.ROLE).courses_with_role() - staff_courses = UserBasedRole(request.user, CourseStaffRole.ROLE).courses_with_role() + with strict_role_checking(): + staff_courses = UserBasedRole(request.user, CourseStaffRole.ROLE).courses_with_role() + all_courses = list(filter(filter_ccx, instructor_courses | staff_courses)) courses_list = [] course_keys = {} diff --git a/cms/djangoapps/contentstore/views/tests/test_block.py b/cms/djangoapps/contentstore/views/tests/test_block.py index 01aff3d613c1..c94038a508b5 100644 --- a/cms/djangoapps/contentstore/views/tests/test_block.py +++ b/cms/djangoapps/contentstore/views/tests/test_block.py @@ -76,7 +76,8 @@ from openedx.core.djangoapps.discussions.models import DiscussionsConfiguration from openedx.core.djangoapps.content_tagging import api as tagging_api -from ..component import component_handler, DEFAULT_ADVANCED_MODULES, get_component_templates +from ..block import _get_safe_return_to +from ..component import component_handler, get_component_templates from cms.djangoapps.contentstore.xblock_storage_handlers.view_handlers import ( ALWAYS, VisibilityState, @@ -2974,10 +2975,23 @@ def test_basic_components(self): self.assertGreater(len(self.get_templates_of_type("html")), 0) self.assertGreater(len(self.get_templates_of_type("problem")), 0) - # Check for default advanced modules + # Check for default advanced modules - only the ones available in test environment advanced_templates = self.get_templates_of_type("advanced") advanced_module_keys = [t['category'] for t in advanced_templates] - self.assertCountEqual(advanced_module_keys, DEFAULT_ADVANCED_MODULES) + expected_advanced_modules = [ + 'annotatable', + 'done', + 'google-calendar', + 'google-document', + 'lti_consumer', + 'poll', + 'split_test', + 'survey', + 'word_cloud', + 'recommender', + 'edx_sga', + ] + self.assertCountEqual(advanced_module_keys, expected_advanced_modules) # Now fully disable video through XBlockConfiguration XBlockConfiguration.objects.create(name="video", enabled=False) @@ -3025,16 +3039,6 @@ def test_advanced_components(self): """ Test the handling of advanced component templates. """ - self.course.advanced_modules.append("done") - EXPECTED_ADVANCED_MODULES_LENGTH = len(DEFAULT_ADVANCED_MODULES) + 1 - self.templates = get_component_templates(self.course) - advanced_templates = self.get_templates_of_type("advanced") - self.assertEqual(len(advanced_templates), EXPECTED_ADVANCED_MODULES_LENGTH) - done_template = advanced_templates[0] - self.assertEqual(done_template.get("category"), "done") - self.assertEqual(done_template.get("display_name"), "Completion") - self.assertIsNone(done_template.get("boilerplate_name", None)) - # Verify that components are not added twice self.course.advanced_modules.append("video") self.course.advanced_modules.append("drag-and-drop-v2") @@ -3045,7 +3049,6 @@ def test_advanced_components(self): self.templates = get_component_templates(self.course) advanced_templates = self.get_templates_of_type("advanced") - self.assertEqual(len(advanced_templates), EXPECTED_ADVANCED_MODULES_LENGTH) only_template = advanced_templates[0] self.assertNotEqual(only_template.get("category"), "video") self.assertNotEqual(only_template.get("category"), "drag-and-drop-v2") @@ -3118,8 +3121,13 @@ def test_create_support_level_flag_off(self): """ XBlockStudioConfigurationFlag.objects.create(enabled=False) self.course.advanced_modules.extend(["annotatable", "done"]) - expected_xblocks = ["Annotation", "Completion"] + self.default_advanced_modules_titles - self._verify_advanced_xblocks(expected_xblocks, [True] * len(expected_xblocks)) + # Get actual templates to determine count dynamically + templates = get_component_templates(self.course) + advanced_templates = templates[-1]["templates"] + expected_count = len(advanced_templates) + # Verify all advanced templates have support_level=True + for template in advanced_templates: + self.assertTrue(template["support_level"]) def test_xblock_masquerading_as_problem(self): """ @@ -4628,3 +4636,93 @@ def test_xblock_edit_view_contains_resources(self): self.assertGreater(len(resource_links), 0, f"No CSS resources found in HTML. Found: {resource_links}") self.assertGreater(len(script_sources), 0, f"No JS resources found in HTML. Found: {script_sources}") + + +class TestGetSafeReturnTo(TestCase): + """ + Tests for _get_safe_return_to validation. + """ + + def setUp(self): + super().setUp() + self.factory = RequestFactory() + + def _make_request(self, return_to=None): + """Build a GET request with an optional returnTo query parameter.""" + url = '/dummy' + if return_to is not None: + url = f'/dummy?returnTo={return_to}' + return self.factory.get(url) + + # -- valid inputs -------------------------------------------------------- + + def test_valid_relative_path(self): + request = self._make_request('/course/123') + self.assertEqual(_get_safe_return_to(request), '/course/123') + + def test_valid_root_path(self): + request = self._make_request('/') + self.assertEqual(_get_safe_return_to(request), '/') + + def test_valid_relative_path_with_query_string(self): + request = self.factory.get('/dummy', {'returnTo': '/course/123?tab=outline'}) + self.assertEqual(_get_safe_return_to(request), '/course/123?tab=outline') + + def test_valid_absolute_url_same_origin(self): + request = self.factory.get('/dummy', {'returnTo': 'http://testserver/course/123'}) + self.assertEqual(_get_safe_return_to(request), 'http://testserver/course/123') + + # -- empty / missing values ---------------------------------------------- + + def test_missing_parameter(self): + request = self.factory.get('/dummy') + self.assertIsNone(_get_safe_return_to(request)) + + def test_empty_string(self): + request = self._make_request('') + self.assertIsNone(_get_safe_return_to(request)) + + def test_whitespace_only(self): + request = self.factory.get('/dummy', {'returnTo': ' '}) + self.assertIsNone(_get_safe_return_to(request)) + + # -- protocol-relative / different origin -------------------------------- + + def test_protocol_relative_url_rejected(self): + request = self._make_request('//evil.com/path') + self.assertIsNone(_get_safe_return_to(request)) + + def test_absolute_url_different_origin_rejected(self): + request = self.factory.get('/dummy', {'returnTo': 'https://evil.com/steal'}) + self.assertIsNone(_get_safe_return_to(request)) + + def test_absolute_url_different_scheme_rejected(self): + request = self.factory.get('/dummy', {'returnTo': 'https://testserver/course'}) + self.assertIsNone(_get_safe_return_to(request)) + + # -- relative paths that don't start with / ------------------------------ + + def test_bare_relative_path_rejected(self): + request = self._make_request('course/123') + self.assertIsNone(_get_safe_return_to(request)) + + # -- control characters -------------------------------------------------- + + def test_null_byte_rejected(self): + request = self.factory.get('/dummy', {'returnTo': '/course/\x00'}) + self.assertIsNone(_get_safe_return_to(request)) + + def test_newline_rejected(self): + request = self.factory.get('/dummy', {'returnTo': '/course/\n/path'}) + self.assertIsNone(_get_safe_return_to(request)) + + def test_tab_character_rejected(self): + request = self.factory.get('/dummy', {'returnTo': '/course/\t/path'}) + self.assertIsNone(_get_safe_return_to(request)) + + # -- length limit -------------------------------------------------------- + + def test_url_over_2048_chars_rejected(self): + long_url = '/' + 'a' * 2048 + request = self.factory.get('/dummy', {'returnTo': long_url}) + self.assertIsNone(_get_safe_return_to(request)) diff --git a/cms/djangoapps/contentstore/views/tests/test_videos.py b/cms/djangoapps/contentstore/views/tests/test_videos.py index 2d17229f59c1..cf2aefb9935d 100644 --- a/cms/djangoapps/contentstore/views/tests/test_videos.py +++ b/cms/djangoapps/contentstore/views/tests/test_videos.py @@ -9,13 +9,12 @@ from contextlib import contextmanager from datetime import datetime from io import StringIO -from unittest.mock import Mock, patch +from unittest.mock import Mock, call, patch import dateutil.parser from common.djangoapps.student.tests.factories import UserFactory import ddt import pytz -from django.test import TestCase from django.conf import settings from django.test.utils import override_settings from django.urls import reverse @@ -33,10 +32,7 @@ from cms.djangoapps.contentstore.tests.utils import CourseTestCase from cms.djangoapps.contentstore.utils import reverse_course_url from openedx.core.djangoapps.profile_images.tests.helpers import make_image_file -from openedx.core.djangoapps.video_pipeline.config.waffle import ( - DEPRECATE_YOUTUBE, - ENABLE_DEVSTACK_VIDEO_UPLOADS, -) +from openedx.core.djangoapps.video_pipeline.config.waffle import DEPRECATE_YOUTUBE from openedx.core.djangoapps.waffle_utils.models import WaffleFlagCourseOverrideModel from xmodule.modulestore.django import modulestore from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase @@ -52,11 +48,13 @@ TranscriptProvider, StatusDisplayStrings, convert_video_status, - storage_service_bucket, - storage_service_key, - PUBLIC_VIDEO_SHARE + PUBLIC_VIDEO_SHARE, ) +# Constant defined to make it clear when we're grabbing the kwargs from a +# unittest.mock.call (which is a list of [args, kwargs]). +CALL_KW = 1 + class VideoUploadTestBase: """ @@ -167,6 +165,40 @@ def _get_previous_upload(self, edx_video_id): if video["edx_video_id"] == edx_video_id ) + @contextmanager + def patch_presign_url(self, files): + """ + Decorator that patches boto3 to mock out S3 URL presigning. + + Assumes that the only client in use is S3, and that only the presigning + method will be called. Makes assertions about what calls were made. + + Decorator yields a result dictionary that will be populated *after* the + context closes. The one key is "calls", a list of call objects to the mock. + + Arguments: + files: List of files to use for upload (dict of file_name and content_type) + """ + mock_gen_url = Mock(side_effect=[ + 'http://example.com/url_{}'.format(file_info['file_name']) + for file_info in files + ]) + mock_s3_client = Mock() + mock_s3_client.generate_presigned_url = mock_gen_url + with patch( + 'cms.djangoapps.contentstore.video_storage_handlers.boto3.client', + return_value=mock_s3_client + ) as mock_boto_client: + results = {} + try: + yield results # run wrapped block + finally: + results['calls'] = mock_gen_url.call_args_list + + # Ensure that we're only trying to load the S3 client + for c in mock_boto_client.call_args_list: + self.assertEqual(c, call('s3')) + class VideoStudioAccessTestsMixin: """ @@ -215,10 +247,7 @@ class VideoUploadPostTestsMixin: """ Shared test cases for video post tests. """ - @override_settings(AWS_ACCESS_KEY_ID='test_key_id', AWS_SECRET_ACCESS_KEY='test_secret') - @patch('boto.s3.key.Key') - @patch('cms.djangoapps.contentstore.video_storage_handlers.S3Connection') - def test_post_success(self, mock_conn, mock_key): + def test_post_success(self): files = [ { 'file_name': 'first.mp4', @@ -238,63 +267,42 @@ def test_post_success(self, mock_conn, mock_key): }, ] - bucket = Mock() - mock_conn.return_value = Mock(get_bucket=Mock(return_value=bucket)) - mock_key_instances = [ - Mock( - generate_url=Mock( - return_value='http://example.com/url_{}'.format(file_info['file_name']) - ) + with self.patch_presign_url(files) as presign_results: + response = self.client.post( + self.url, + json.dumps({'files': files}), + content_type='application/json' ) - for file_info in files - ] - # If extra calls are made, return a dummy - mock_key.side_effect = mock_key_instances + [Mock()] - - response = self.client.post( - self.url, - json.dumps({'files': files}), - content_type='application/json' - ) self.assertEqual(response.status_code, 200) response_obj = json.loads(response.content.decode('utf-8')) - mock_conn.assert_called_once_with( - aws_access_key_id=settings.AWS_ACCESS_KEY_ID, - aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY - ) self.assertEqual(len(response_obj['files']), len(files)) - self.assertEqual(mock_key.call_count, len(files)) + presign_calls = presign_results['calls'] + self.assertEqual(len(presign_calls), len(files)) for i, file_info in enumerate(files): - # Ensure Key was set up correctly and extract id - key_call_args, __ = mock_key.call_args_list[i] - self.assertEqual(key_call_args[0], bucket) + call_kwargs = presign_calls[i][CALL_KW] + + self.assertEqual(call_kwargs['ClientMethod'], 'put_object') path_match = re.match( ( settings.VIDEO_UPLOAD_PIPELINE['ROOT_PATH'] + '/([a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12})$' ), - key_call_args[1] + call_kwargs['Params']['Key'] ) self.assertIsNotNone(path_match) video_id = path_match.group(1) - mock_key_instance = mock_key_instances[i] - - mock_key_instance.set_metadata.assert_any_call( - 'course_video_upload_token', - self.test_token - ) - mock_key_instance.set_metadata.assert_any_call( - 'client_video_id', - file_info['file_name'] - ) - mock_key_instance.set_metadata.assert_any_call('course_key', str(self.course.id)) - mock_key_instance.generate_url.assert_called_once_with( - KEY_EXPIRATION_IN_SECONDS, - 'PUT', - headers={'Content-Type': file_info['content_type']} + self.assertEqual( + call_kwargs['Params']['Metadata'], + { + 'course_video_upload_token': self.test_token, + 'client_video_id': file_info['file_name'], + 'course_key': str(self.course.id), + } ) + self.assertEqual(call_kwargs['Params']['ContentType'], file_info['content_type']) + self.assertEqual(call_kwargs['ExpiresIn'], KEY_EXPIRATION_IN_SECONDS) # Ensure VAL was updated val_info = get_video_info(video_id) @@ -307,7 +315,7 @@ def test_post_success(self, mock_conn, mock_key): # Ensure response is correct response_file = response_obj['files'][i] self.assertEqual(response_file['file_name'], file_info['file_name']) - self.assertEqual(response_file['upload_url'], mock_key_instance.generate_url()) + self.assertEqual(response_file['upload_url'], f"http://example.com/url_{file_info['file_name']}") def test_post_non_json(self): response = self.client.post(self.url, {"files": []}) @@ -479,9 +487,6 @@ def test_get_html_paginated(self): self.assertEqual(response.status_code, 200) self.assertContains(response, 'video_upload_pagination') - @override_settings(AWS_ACCESS_KEY_ID="test_key_id", AWS_SECRET_ACCESS_KEY="test_secret") - @patch("boto.s3.key.Key") - @patch("cms.djangoapps.contentstore.video_storage_handlers.S3Connection") @ddt.data( ( [ @@ -511,28 +516,17 @@ def test_get_html_paginated(self): ) ) @ddt.unpack - def test_video_supported_file_formats(self, files, expected_status, mock_conn, mock_key): + def test_video_supported_file_formats(self, files, expected_status): """ Test that video upload works correctly against supported and unsupported file formats. """ - mock_conn.get_bucket = Mock() - mock_key_instances = [ - Mock( - generate_url=Mock( - return_value="http://example.com/url_{}".format(file_info["file_name"]) - ) - ) - for file_info in files - ] - # If extra calls are made, return a dummy - mock_key.side_effect = mock_key_instances + [Mock()] - # Check supported formats - response = self.client.post( - self.url, - json.dumps({"files": files}), - content_type="application/json" - ) + with self.patch_presign_url(files): + response = self.client.post( + self.url, + json.dumps({"files": files}), + content_type="application/json" + ) self.assertEqual(response.status_code, expected_status) response = json.loads(response.content.decode('utf-8')) @@ -542,19 +536,12 @@ def test_video_supported_file_formats(self, files, expected_status, mock_conn, m self.assertIn('error', response) self.assertEqual(response['error'], "Request 'files' entry contain unsupported content_type") - @override_settings(AWS_ACCESS_KEY_ID='test_key_id', AWS_SECRET_ACCESS_KEY='test_secret') - @patch('cms.djangoapps.contentstore.video_storage_handlers.S3Connection') - def test_upload_with_non_ascii_charaters(self, mock_conn): + def test_upload_with_non_ascii_characters(self): """ Test that video uploads throws error message when file name contains special characters. """ - mock_conn.get_bucket = Mock() file_name = 'test\u2019_file.mp4' files = [{'file_name': file_name, 'content_type': 'video/mp4'}] - - bucket = Mock() - mock_conn.return_value = Mock(get_bucket=Mock(return_value=bucket)) - response = self.client.post( self.url, json.dumps({'files': files}), @@ -564,67 +551,24 @@ def test_upload_with_non_ascii_charaters(self, mock_conn): response = json.loads(response.content.decode('utf-8')) self.assertEqual(response['error'], 'The file name for %s must contain only ASCII characters.' % file_name) - @override_settings(AWS_ACCESS_KEY_ID='test_key_id', AWS_SECRET_ACCESS_KEY='test_secret', AWS_SECURITY_TOKEN='token') - @patch('boto.s3.key.Key') - @patch('cms.djangoapps.contentstore.video_storage_handlers.S3Connection') - @override_waffle_flag(ENABLE_DEVSTACK_VIDEO_UPLOADS, active=True) - def test_devstack_upload_connection(self, mock_conn, mock_key): - files = [{'file_name': 'first.mp4', 'content_type': 'video/mp4'}] - mock_conn.get_bucket = Mock() - mock_key_instances = [ - Mock( - generate_url=Mock( - return_value='http://example.com/url_{}'.format(file_info['file_name']) - ) - ) - for file_info in files - ] - mock_key.side_effect = mock_key_instances - response = self.client.post( - self.url, - json.dumps({'files': files}), - content_type='application/json' - ) - - self.assertEqual(response.status_code, 200) - mock_conn.assert_called_once_with( - aws_access_key_id=settings.AWS_ACCESS_KEY_ID, - aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY, - security_token=settings.AWS_SECURITY_TOKEN - ) - - @patch('boto.s3.key.Key') - @patch('cms.djangoapps.contentstore.video_storage_handlers.S3Connection') - def test_send_course_to_vem_pipeline(self, mock_conn, mock_key): + def test_send_course_to_vem_pipeline(self): """ Test that uploads always go to VEM S3 bucket by default. """ - mock_conn.get_bucket = Mock() files = [{'file_name': 'first.mp4', 'content_type': 'video/mp4'}] - mock_key_instances = [ - Mock( - generate_url=Mock( - return_value='http://example.com/url_{}'.format(file_info['file_name']) - ) + with self.patch_presign_url(files) as presign_results: + response = self.client.post( + self.url, + json.dumps({'files': files}), + content_type='application/json' ) - for file_info in files - ] - mock_key.side_effect = mock_key_instances - - response = self.client.post( - self.url, - json.dumps({'files': files}), - content_type='application/json' - ) self.assertEqual(response.status_code, 200) - mock_conn.return_value.get_bucket.assert_called_once_with( - settings.VIDEO_UPLOAD_PIPELINE['VEM_S3_BUCKET'], validate=False # pylint: disable=unsubscriptable-object + self.assertEqual( + presign_results['calls'][0][CALL_KW]['Params']['Bucket'], + settings.VIDEO_UPLOAD_PIPELINE['VEM_S3_BUCKET'] ) - @override_settings(AWS_ACCESS_KEY_ID='test_key_id', AWS_SECRET_ACCESS_KEY='test_secret') - @patch('boto.s3.key.Key') - @patch('cms.djangoapps.contentstore.video_storage_handlers.S3Connection') @ddt.data( { 'global_waffle': True, @@ -642,52 +586,26 @@ def test_send_course_to_vem_pipeline(self, mock_conn, mock_key): 'expect_token': True } ) - def test_video_upload_token_in_meta(self, data, mock_conn, mock_key): + def test_video_upload_token_in_meta(self, data): """ Test video upload token in s3 metadata. """ - @contextmanager - def proxy_manager(manager, ignore_manager): - """ - This acts as proxy to the original manager in the arguments given - the original manager is not set to be ignored. - """ - if ignore_manager: - yield - else: - with manager: - yield - file_data = { 'file_name': 'first.mp4', 'content_type': 'video/mp4', } - mock_conn.get_bucket = Mock() - mock_key_instance = Mock( - generate_url=Mock( - return_value='http://example.com/url_{}'.format(file_data['file_name']) - ) - ) - # If extra calls are made, return a dummy - mock_key.side_effect = [mock_key_instance] - - # expected args to be passed to `set_metadata`. - expected_args = ('course_video_upload_token', self.test_token) - with patch.object(WaffleFlagCourseOverrideModel, 'override_value', return_value=data['course_override']): with override_waffle_flag(DEPRECATE_YOUTUBE, active=data['global_waffle']): - response = self.client.post( - self.url, - json.dumps({'files': [file_data]}), - content_type='application/json' - ) + with self.patch_presign_url([file_data]) as presign_results: + response = self.client.post( + self.url, + json.dumps({'files': [file_data]}), + content_type='application/json' + ) self.assertEqual(response.status_code, 200) - with proxy_manager(self.assertRaises(AssertionError), data['expect_token']): - # if we're not expecting token then following should raise assertion error and - # if we're expecting token then we will be able to find the call to set the token - # in s3 metadata. - mock_key_instance.set_metadata.assert_any_call(*expected_args) + actual_token = presign_results['calls'][0][CALL_KW]['Params']['Metadata'].get('course_video_upload_token') + self.assertEqual(actual_token, self.test_token if data['expect_token'] else None) def _assert_video_removal(self, url, edx_video_id, deleted_videos): """ @@ -1460,47 +1378,37 @@ def test_remove_transcript_preferences_not_found(self): ) @ddt.unpack @override_settings(AWS_ACCESS_KEY_ID='test_key_id', AWS_SECRET_ACCESS_KEY='test_secret') - @patch('boto.s3.key.Key') - @patch('cms.djangoapps.contentstore.video_storage_handlers.S3Connection') @patch('cms.djangoapps.contentstore.video_storage_handlers.get_transcript_preferences') def test_transcript_preferences_metadata(self, transcript_preferences, is_video_transcript_enabled, - mock_transcript_preferences, mock_conn, mock_key): + mock_transcript_preferences): """ Tests that transcript preference metadata is only set if it is video transcript feature is enabled and transcript preferences are already stored in the system. """ file_name = 'test-video.mp4' - request_data = {'files': [{'file_name': file_name, 'content_type': 'video/mp4'}]} + files = [{'file_name': file_name, 'content_type': 'video/mp4'}] mock_transcript_preferences.return_value = transcript_preferences - bucket = Mock() - mock_conn.return_value = Mock(get_bucket=Mock(return_value=bucket)) - mock_key_instance = Mock( - generate_url=Mock( - return_value=f'http://example.com/url_{file_name}' - ) - ) - # If extra calls are made, return a dummy - mock_key.side_effect = [mock_key_instance] + [Mock()] - videos_handler_url = reverse_course_url('videos_handler', self.course.id) with patch( 'openedx.core.djangoapps.video_config.models.VideoTranscriptEnabledFlag.feature_enabled' ) as video_transcript_feature: video_transcript_feature.return_value = is_video_transcript_enabled - response = self.client.post(videos_handler_url, json.dumps(request_data), content_type='application/json') + with self.patch_presign_url(files) as presign_results: + response = self.client.post( + videos_handler_url, json.dumps({'files': files}), + content_type='application/json', + ) self.assertEqual(response.status_code, 200) - # Ensure `transcript_preferences` was set up in Key correctly if sent through request. + # Ensure `transcript_preferences` was set in metadata correctly if sent through request. + actual_value = presign_results['calls'][0][CALL_KW]['Params']['Metadata'].get('transcript_preferences') if is_video_transcript_enabled and transcript_preferences: - mock_key_instance.set_metadata.assert_any_call('transcript_preferences', json.dumps(transcript_preferences)) + self.assertEqual(actual_value, json.dumps(transcript_preferences)) else: - with self.assertRaises(AssertionError): - mock_key_instance.set_metadata.assert_any_call( - 'transcript_preferences', json.dumps(transcript_preferences) - ) + self.assertEqual(actual_value, None) @patch.dict("django.conf.settings.FEATURES", {"ENABLE_VIDEO_UPLOAD_PIPELINE": True}) @@ -1644,29 +1552,6 @@ def _test_video_feature(self, flag, key, override_fn, is_enabled): self.assertEqual(response.json()[key], is_enabled) -class GetStorageBucketTestCase(TestCase): - """ This test just check that connection works and returns the bucket. - It does not involve any mocking and triggers errors if has any import issue. - """ - @override_settings(AWS_ACCESS_KEY_ID='test_key_id', AWS_SECRET_ACCESS_KEY='test_secret') - @override_settings(VIDEO_UPLOAD_PIPELINE={ - "VEM_S3_BUCKET": "vem_test_bucket", "BUCKET": "test_bucket", "ROOT_PATH": "test_root" - }) - def test_storage_bucket(self): - """ get bucket and generate url. It will not hit actual s3.""" - bucket = storage_service_bucket() - edx_video_id = 'dummy_video' - key = storage_service_key(bucket, file_name=edx_video_id) - upload_url = key.generate_url( - KEY_EXPIRATION_IN_SECONDS, - 'PUT', - headers={'Content-Type': 'mp4'} - ) - - self.assertIn("https://vem_test_bucket.s3.amazonaws.com:443/test_root/", upload_url) - self.assertIn(edx_video_id, upload_url) - - class CourseYoutubeEdxVideoIds(ModuleStoreTestCase): """ This test checks youtube videos in a course diff --git a/cms/djangoapps/contentstore/views/videos.py b/cms/djangoapps/contentstore/views/videos.py index 2eac141b9c9e..499c73c29a23 100644 --- a/cms/djangoapps/contentstore/views/videos.py +++ b/cms/djangoapps/contentstore/views/videos.py @@ -23,8 +23,6 @@ videos_index_html as videos_index_html_source_function, videos_index_json as videos_index_json_source_function, videos_post as videos_post_source_function, - storage_service_bucket as storage_service_bucket_source_function, - storage_service_key as storage_service_key_source_function, send_video_status_update as send_video_status_update_source_function, is_status_update_request as is_status_update_request_source_function, get_course_youtube_edx_video_ids, @@ -212,20 +210,6 @@ def videos_post(course, request): return videos_post_source_function(course, request) -def storage_service_bucket(): - """ - Exposes helper method without breaking existing bindings/dependencies - """ - return storage_service_bucket_source_function() - - -def storage_service_key(bucket, file_name): - """ - Exposes helper method without breaking existing bindings/dependencies - """ - return storage_service_key_source_function(bucket, file_name) - - def send_video_status_update(updates): """ Exposes helper method without breaking existing bindings/dependencies diff --git a/cms/static/images/large-games-icon.svg b/cms/static/images/large-games-icon.svg new file mode 100644 index 000000000000..d23b38c7748d --- /dev/null +++ b/cms/static/images/large-games-icon.svg @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/cms/static/js/spec/views/modals/edit_xblock_spec.js b/cms/static/js/spec/views/modals/edit_xblock_spec.js index 06b4413784cb..0855a850b97d 100644 --- a/cms/static/js/spec/views/modals/edit_xblock_spec.js +++ b/cms/static/js/spec/views/modals/edit_xblock_spec.js @@ -185,6 +185,68 @@ describe('EditXBlockModal', function() { }); }); + describe('isSafeReturnTo', function() { + var isSafeReturnTo = EditXBlockModal.isSafeReturnTo; + + // -- valid inputs --------------------------------------------------- + + it('accepts a root-relative path', function() { + expect(isSafeReturnTo('/course/123')).toBe(true); + }); + + it('accepts the root path', function() { + expect(isSafeReturnTo('/')).toBe(true); + }); + + it('accepts a root-relative path with query string', function() { + expect(isSafeReturnTo('/course/123?tab=outline')).toBe(true); + }); + + it('accepts an absolute URL with the same origin', function() { + expect(isSafeReturnTo(window.location.origin + '/course/123')).toBe(true); + }); + + // -- empty / missing values ----------------------------------------- + + it('rejects an empty string', function() { + expect(isSafeReturnTo('')).toBe(false); + }); + + it('rejects null', function() { + expect(isSafeReturnTo(null)).toBe(false); + }); + + it('rejects undefined', function() { + expect(isSafeReturnTo(undefined)).toBe(false); + }); + + it('rejects a non-string value', function() { + expect(isSafeReturnTo(42)).toBe(false); + }); + + // -- protocol-relative / different origin --------------------------- + + it('rejects protocol-relative URLs', function() { + expect(isSafeReturnTo('//evil.com/path')).toBe(false); + }); + + it('rejects an absolute URL with a different origin', function() { + expect(isSafeReturnTo('https://evil.com/steal')).toBe(false); + }); + + // -- relative paths without leading / ------------------------------- + + it('rejects a bare relative path', function() { + expect(isSafeReturnTo('course/123')).toBe(false); + }); + + // -- javascript: scheme --------------------------------------------- + + it('rejects javascript: scheme', function() { + expect(isSafeReturnTo('javascript:alert(1)')).toBe(false); // eslint-disable-line no-script-url + }); + }); + describe('XModule Editor (settings only)', function() { var mockXModuleEditorHtml; diff --git a/cms/static/js/views/modals/edit_xblock.js b/cms/static/js/views/modals/edit_xblock.js index 586d27d8b284..aef409c56150 100644 --- a/cms/static/js/views/modals/edit_xblock.js +++ b/cms/static/js/views/modals/edit_xblock.js @@ -8,6 +8,27 @@ define(['jquery', 'underscore', 'backbone', 'gettext', 'js/views/modals/base_mod function($, _, Backbone, gettext, BaseModal, ViewUtils, XBlockViewUtils, XBlockEditorView) { 'use strict'; + /** + * Returns true when ``url`` is safe to use as a same-origin redirect + * destination. Accepted values are: + * - Root-relative paths beginning with '/' but NOT '//' (protocol- + * relative URLs such as //evil.com would bypass the check). + * - Absolute URLs whose origin matches the current page's origin. + */ + function isSafeReturnTo(url) { + if (typeof url !== 'string' || url.length === 0) { + return false; + } + if (url.charAt(0) === '/' && url.charAt(1) !== '/') { + return true; + } + try { + return new URL(url).origin === window.location.origin; + } catch (e) { + return false; + } + } + var EditXBlockModal = BaseModal.extend({ events: _.extend({}, BaseModal.prototype.events, { 'click .action-save': 'save', @@ -234,6 +255,13 @@ function($, _, Backbone, gettext, BaseModal, ViewUtils, XBlockViewUtils, XBlockE console.error(e); } + var returnTo = this.editOptions && this.editOptions.returnTo; + if (returnTo && isSafeReturnTo(returnTo)) { + this.hide(); + window.location.href = returnTo; + return; + } + var refresh = this.editOptions.refresh; this.hide(); if (refresh) { @@ -241,6 +269,18 @@ function($, _, Backbone, gettext, BaseModal, ViewUtils, XBlockViewUtils, XBlockE } }, + cancel: function(event) { + if (event) { + event.preventDefault(); + event.stopPropagation(); + } + this.hide(); + var returnTo = this.editOptions && this.editOptions.returnTo; + if (returnTo && isSafeReturnTo(returnTo)) { + window.location.href = returnTo; + } + }, + hide: function() { // Notify child views to stop listening events Backbone.trigger('xblock:editorModalHidden'); @@ -296,5 +336,8 @@ function($, _, Backbone, gettext, BaseModal, ViewUtils, XBlockViewUtils, XBlockE }); + // Expose for unit testing. + EditXBlockModal.isSafeReturnTo = isSafeReturnTo; + return EditXBlockModal; }); diff --git a/cms/static/js/views/pages/container.js b/cms/static/js/views/pages/container.js index d50f6b4bbe4a..f8e67a47c074 100644 --- a/cms/static/js/views/pages/container.js +++ b/cms/static/js/views/pages/container.js @@ -512,6 +512,7 @@ function($, _, Backbone, gettext, BasePage, if((useNewTextEditor === 'True' && blockType === 'html') || (useNewVideoEditor === 'True' && blockType === 'video') || (useNewProblemEditor === 'True' && blockType === 'problem') + || (blockType === 'games') || (blockType === 'invideoquiz') ) { var destinationUrl = primaryHeader.attr('authoring_MFE_base_url') + '/' + blockType @@ -1183,7 +1184,9 @@ function($, _, Backbone, gettext, BasePage, // open mfe editors for new blocks only and not for content imported from libraries if(!data.hasOwnProperty('upstreamRef') && ((useNewTextEditor === 'True' && blockType.includes('html')) || (useNewVideoEditor === 'True' && blockType.includes('video')) - || (useNewProblemEditor === 'True' && blockType.includes('problem'))) + || (useNewProblemEditor === 'True' && blockType.includes('problem')) + || blockType.includes('games') + || blockType.includes('invideoquiz')) ){ if (this.options.isIframeEmbed && (this.isSplitTestContentPage || this.isVerticalContentPage)) { return this.postMessageToParent({ diff --git a/cms/static/sass/assets/_graphics.scss b/cms/static/sass/assets/_graphics.scss index afb830d5dd71..58141c445a75 100644 --- a/cms/static/sass/assets/_graphics.scss +++ b/cms/static/sass/assets/_graphics.scss @@ -80,3 +80,9 @@ height: ($baseline*3); background: url('#{$static-path}/images/large-itembank-icon.png') center no-repeat; } + +.large-games-icon { + display: inline-block; + width: ($baseline*3); + height: ($baseline*3); + background: url('#{$static-path}/images/large-games-icon.svg') center no-repeat; } diff --git a/cms/static/sass/elements/_modules.scss b/cms/static/sass/elements/_modules.scss index 1e9f52691fc8..f1c871ba1d95 100644 --- a/cms/static/sass/elements/_modules.scss +++ b/cms/static/sass/elements/_modules.scss @@ -163,8 +163,8 @@ display: inline-block; color: $uxpl-green-base; - background-color: theme-color("inverse"); - border-color: theme-color("inverse"); + background-color: $white; + border-color: $white; border-radius: 3px; font-size: 90%; diff --git a/cms/templates/container_editor.html b/cms/templates/container_editor.html index e7585d7b9664..900dac80abf9 100644 --- a/cms/templates/container_editor.html +++ b/cms/templates/container_editor.html @@ -119,12 +119,13 @@ function (XBlockInfo, EditXBlockModal) { var decodedActionName = '${action_name|n, decode.utf8}'; var encodedXBlockDetails = ${xblock_info | n, dump_js_escaped_json}; + var returnTo = '${return_to or "" | n, js_escaped_string}'; if (decodedActionName === 'edit') { var editXBlockModal = new EditXBlockModal(); var xblockInfoInstance = new XBlockInfo(encodedXBlockDetails); - editXBlockModal.edit([], xblockInfoInstance, {}); + editXBlockModal.edit([], xblockInfoInstance, {returnTo: returnTo || null}); } }); diff --git a/cms/templates/js/show-correctness-editor.underscore b/cms/templates/js/show-correctness-editor.underscore index 3db6c3c27a5c..1b0dd896747a 100644 --- a/cms/templates/js/show-correctness-editor.underscore +++ b/cms/templates/js/show-correctness-editor.underscore @@ -35,6 +35,13 @@ <% } %> <%- gettext('If the subsection does not have a due date, learners always see their scores when they submit answers to assessments.') %>

+ +

+ <%- gettext('Learners do not see question-level correctness or scores before or after the due date. However, once the due date passes, they can see their overall score for the subsection on the Progress page.') %> +

diff --git a/common/djangoapps/student/auth.py b/common/djangoapps/student/auth.py index e199142fe377..047f0174a062 100644 --- a/common/djangoapps/student/auth.py +++ b/common/djangoapps/student/auth.py @@ -24,6 +24,7 @@ OrgInstructorRole, OrgLibraryUserRole, OrgStaffRole, + strict_role_checking, ) # Studio permissions: @@ -115,8 +116,9 @@ def get_user_permissions(user, course_key, org=None, service_variant=None): return STUDIO_NO_PERMISSIONS # Staff have all permissions except EDIT_ROLES: - if OrgStaffRole(org=org).has_user(user) or (course_key and user_has_role(user, CourseStaffRole(course_key))): - return STUDIO_VIEW_USERS | STUDIO_EDIT_CONTENT | STUDIO_VIEW_CONTENT + with strict_role_checking(): + if OrgStaffRole(org=org).has_user(user) or (course_key and user_has_role(user, CourseStaffRole(course_key))): + return STUDIO_VIEW_USERS | STUDIO_EDIT_CONTENT | STUDIO_VIEW_CONTENT # Otherwise, for libraries, users can view only: if course_key and isinstance(course_key, LibraryLocator): diff --git a/common/djangoapps/student/tests/test_authz.py b/common/djangoapps/student/tests/test_authz.py index c0b88e6318b5..70636e04b68a 100644 --- a/common/djangoapps/student/tests/test_authz.py +++ b/common/djangoapps/student/tests/test_authz.py @@ -11,6 +11,7 @@ from django.test import TestCase, override_settings from opaque_keys.edx.locator import CourseLocator +from common.djangoapps.student.models.user import CourseAccessRole from common.djangoapps.student.auth import ( add_users, has_studio_read_access, @@ -305,6 +306,23 @@ def test_limited_staff_no_studio_access_cms(self): assert not has_studio_read_access(self.limited_staff, self.course_key) assert not has_studio_write_access(self.limited_staff, self.course_key) + @override_settings(SERVICE_VARIANT='cms') + def test_limited_org_staff_no_studio_access_cms(self): + """ + Verifies that course limited staff have no read and no write access when SERVICE_VARIANT is not 'lms'. + """ + # Add a user as course_limited_staff on the org + # This is not possible using the course roles classes but is possible via Django admin so we + # insert a row into the model directly to test that scenario. + CourseAccessRole.objects.create( + user=self.limited_staff, + org=self.course_key.org, + role=CourseLimitedStaffRole.ROLE, + ) + + assert not has_studio_read_access(self.limited_staff, self.course_key) + assert not has_studio_write_access(self.limited_staff, self.course_key) + class CourseOrgGroupTest(TestCase): """ diff --git a/common/djangoapps/third_party_auth/api/serializers.py b/common/djangoapps/third_party_auth/api/serializers.py index 3e8513de7312..a510cbe07a07 100644 --- a/common/djangoapps/third_party_auth/api/serializers.py +++ b/common/djangoapps/third_party_auth/api/serializers.py @@ -20,4 +20,7 @@ def get_username(self, social_user): def get_remote_id(self, social_user): """ Gets remote id from social user based on provider """ + remote_id_field_name = self.context.get('remote_id_field_name', None) + if remote_id_field_name: + return self.provider.get_remote_id_from_field_name(social_user, remote_id_field_name) return self.provider.get_remote_id_from_social_auth(social_user) diff --git a/common/djangoapps/third_party_auth/api/tests/test_views.py b/common/djangoapps/third_party_auth/api/tests/test_views.py index f7834001d66b..61740268db90 100644 --- a/common/djangoapps/third_party_auth/api/tests/test_views.py +++ b/common/djangoapps/third_party_auth/api/tests/test_views.py @@ -38,8 +38,10 @@ PASSWORD = "edx" -def get_mapping_data_by_usernames(usernames): +def get_mapping_data_by_usernames(usernames, remote_id_field_name=False): """ Generate mapping data used in response """ + if remote_id_field_name: + return [{'username': username, 'remote_id': 'external_' + username} for username in usernames] return [{'username': username, 'remote_id': 'remote_' + username} for username in usernames] @@ -76,11 +78,13 @@ def setUp(self): # pylint: disable=arguments-differ provider=google.backend_name, uid=f'{username}@gmail.com', ) - UserSocialAuth.objects.create( + usa = UserSocialAuth.objects.create( user=user, provider=testshib.backend_name, uid=f'{testshib.slug}:remote_{username}', ) + usa.set_extra_data({'external_user_id': f'external_{username}'}) + usa.refresh_from_db() # Create another user not linked to any providers: UserFactory.create(username=CARL_USERNAME, email=f'{CARL_USERNAME}@example.com', password=PASSWORD) @@ -304,12 +308,20 @@ def test_list_all_user_mappings_oauth2(self, valid_call, expect_code, expect_dat @ddt.data( ({'username': [ALICE_USERNAME, STAFF_USERNAME]}, 200, get_mapping_data_by_usernames([ALICE_USERNAME, STAFF_USERNAME])), + ({'username': [ALICE_USERNAME, STAFF_USERNAME], 'remote_id_field_name': 'external_user_id'}, 200, + get_mapping_data_by_usernames([ALICE_USERNAME, STAFF_USERNAME], remote_id_field_name=True)), ({'remote_id': ['remote_' + ALICE_USERNAME, 'remote_' + STAFF_USERNAME, 'remote_' + CARL_USERNAME]}, 200, get_mapping_data_by_usernames([ALICE_USERNAME, STAFF_USERNAME])), + ({'remote_id': ['remote_' + ALICE_USERNAME, 'remote_' + STAFF_USERNAME, 'remote_' + CARL_USERNAME], + 'remote_id_field_name': 'external_user_id'}, 200, + get_mapping_data_by_usernames([ALICE_USERNAME, STAFF_USERNAME], remote_id_field_name=True)), ({'username': [ALICE_USERNAME, CARL_USERNAME, STAFF_USERNAME]}, 200, get_mapping_data_by_usernames([ALICE_USERNAME, STAFF_USERNAME])), ({'username': [ALICE_USERNAME], 'remote_id': ['remote_' + STAFF_USERNAME]}, 200, get_mapping_data_by_usernames([ALICE_USERNAME, STAFF_USERNAME])), + ({'username': [ALICE_USERNAME], 'remote_id': ['remote_' + STAFF_USERNAME], + 'remote_id_field_name': 'external_user_id'}, 200, + get_mapping_data_by_usernames([ALICE_USERNAME, STAFF_USERNAME], remote_id_field_name=True)), ) @ddt.unpack def test_user_mappings_with_query_params_comma_separated(self, query_params, expect_code, expect_data): @@ -321,6 +333,8 @@ def test_user_mappings_with_query_params_comma_separated(self, query_params, exp for attr in ['username', 'remote_id']: if attr in query_params: params.append('{}={}'.format(attr, ','.join(query_params[attr]))) + if 'remote_id_field_name' in query_params: + params.append('remote_id_field_name={}'.format(query_params['remote_id_field_name'])) url = "{}?{}".format(base_url, '&'.join(params)) response = self.client.get(url, HTTP_X_EDX_API_KEY=VALID_API_KEY) self._verify_response(response, expect_code, expect_data) @@ -328,12 +342,20 @@ def test_user_mappings_with_query_params_comma_separated(self, query_params, exp @ddt.data( ({'username': [ALICE_USERNAME, STAFF_USERNAME]}, 200, get_mapping_data_by_usernames([ALICE_USERNAME, STAFF_USERNAME])), + ({'username': [ALICE_USERNAME, STAFF_USERNAME], 'remote_id_field_name': 'external_user_id'}, 200, + get_mapping_data_by_usernames([ALICE_USERNAME, STAFF_USERNAME], remote_id_field_name=True)), ({'remote_id': ['remote_' + ALICE_USERNAME, 'remote_' + STAFF_USERNAME, 'remote_' + CARL_USERNAME]}, 200, get_mapping_data_by_usernames([ALICE_USERNAME, STAFF_USERNAME])), + ({'remote_id': ['remote_' + ALICE_USERNAME, 'remote_' + STAFF_USERNAME, 'remote_' + CARL_USERNAME], + 'remote_id_field_name': 'external_user_id'}, 200, + get_mapping_data_by_usernames([ALICE_USERNAME, STAFF_USERNAME], remote_id_field_name=True)), ({'username': [ALICE_USERNAME, CARL_USERNAME, STAFF_USERNAME]}, 200, get_mapping_data_by_usernames([ALICE_USERNAME, STAFF_USERNAME])), ({'username': [ALICE_USERNAME], 'remote_id': ['remote_' + STAFF_USERNAME]}, 200, get_mapping_data_by_usernames([ALICE_USERNAME, STAFF_USERNAME])), + ({'username': [ALICE_USERNAME], 'remote_id': ['remote_' + STAFF_USERNAME], + 'remote_id_field_name': 'external_user_id'}, 200, + get_mapping_data_by_usernames([ALICE_USERNAME, STAFF_USERNAME], remote_id_field_name=True)), ) @ddt.unpack def test_user_mappings_with_query_params_multi_value_key(self, query_params, expect_code, expect_data): @@ -345,6 +367,8 @@ def test_user_mappings_with_query_params_multi_value_key(self, query_params, exp for attr in ['username', 'remote_id']: if attr in query_params: params.setlist(attr, query_params[attr]) + if 'remote_id_field_name' in query_params: + params['remote_id_field_name'] = query_params['remote_id_field_name'] url = f"{base_url}?{params.urlencode()}" response = self.client.get(url, HTTP_X_EDX_API_KEY=VALID_API_KEY) self._verify_response(response, expect_code, expect_data) diff --git a/common/djangoapps/third_party_auth/api/views.py b/common/djangoapps/third_party_auth/api/views.py index c2b8b0dd6f39..89d55e2eecdc 100644 --- a/common/djangoapps/third_party_auth/api/views.py +++ b/common/djangoapps/third_party_auth/api/views.py @@ -323,6 +323,9 @@ class UserMappingView(ListAPIView): GET /api/third_party_auth/v0/providers/{provider_id}/users?username={username1},{username2} + GET /api/third_party_auth/v0/providers/{provider_id}/users?username={username1}& + remote_id_field_name={external_id_field_name} + GET /api/third_party_auth/v0/providers/{provider_id}/users?username={username1}&usernames={username2} GET /api/third_party_auth/v0/providers/{provider_id}/users?remote_id={remote_id1},{remote_id2} @@ -346,6 +349,9 @@ class UserMappingView(ListAPIView): * usernames: Optional. List of comma separated edX usernames to filter the result set. e.g. ?usernames=bob123,jane456 + * remote_id_field_name: Optional. The field name to use for the remote id lookup. + Useful when learners are coming from external LMS. e.g. ?remote_id_field_name=ext_userid_sf + * page, page_size: Optional. Used for paging the result set, especially when getting an unfiltered list. @@ -415,6 +421,7 @@ def get_serializer_context(self): remove idp_slug from the remote_id if there is any """ context = super().get_serializer_context() + context['remote_id_field_name'] = self.request.query_params.get('remote_id_field_name', None) context['provider'] = self.provider return context diff --git a/common/djangoapps/third_party_auth/docs/how_tos/testing_saml_locally.rst b/common/djangoapps/third_party_auth/docs/how_tos/testing_saml_locally.rst new file mode 100644 index 000000000000..2be1cc1dacf6 --- /dev/null +++ b/common/djangoapps/third_party_auth/docs/how_tos/testing_saml_locally.rst @@ -0,0 +1,156 @@ +Testing SAML Authentication Locally with MockSAML +================================================== + +This guide walks through setting up and testing SAML authentication in a local Open edX devstack environment using MockSAML.com as a test Identity Provider (IdP). + +Overview +-------- + +SAML (Security Assertion Markup Language) authentication in Open edX requires three configuration objects to work together: + +1. **SAMLConfiguration**: Configures the Service Provider (SP) metadata - entity ID, keys, and organization info +2. **SAMLProviderConfig**: Configures a specific Identity Provider (IdP) connection with metadata URL and attribute mappings +3. **SAMLProviderData**: Stores the IdP's metadata (SSO URL, public key) fetched from the IdP's metadata endpoint + +**Critical Requirement**: The SAMLConfiguration object MUST have the slug "default" because this value is hardcoded in the authentication execution path at ``common/djangoapps/third_party_auth/models.py:906``. + +Prerequisites +------------- + +* Local Open edX devstack running +* Access to Django admin at http://localhost:18000/admin/ +* MockSAML.com account (free service for SAML testing) + +Step 1: Configure SAMLConfiguration +------------------------------------ + +The SAMLConfiguration defines your Open edX instance as a SAML Service Provider (SP). + +1. Navigate to Django Admin → Third Party Auth → SAML Configurations +2. Click "Add SAML Configuration" +3. Configure with these **required** values: + + ============ =================================================== + Field Value + ============ =================================================== + Site localhost:18000 + **Slug** **default** (MUST be "default" - hardcoded in code) + Entity ID https://saml.example.com/entityid + Enabled ✓ (checked) + ============ =================================================== + +4. For local testing with MockSAML, you can leave the keys blank. + +5. Optionally configure Organization Info (use default or customize): + + .. code-block:: json + + { + "en-US": { + "url": "http://localhost:18000", + "displayname": "Local Open edX", + "name": "localhost" + } + } + +6. Click "Save" + +Step 2: Configure SAMLProviderConfig +------------------------------------- + +The SAMLProviderConfig connects to a specific SAML Identity Provider (MockSAML in this case). + +1. Navigate to Django Admin → Third Party Auth → Provider Configuration (SAML IdPs) +2. Click "Add Provider Configuration (SAML IdP)" +3. Configure with these values: + + ========================= =================================================== + Field Value + ========================= =================================================== + Name Test Localhost (or any descriptive name) + Slug default (to match test URLs) + Backend Name tpa-saml + Entity ID https://saml.example.com/entityid + Metadata Source https://mocksaml.com/api/saml/metadata + Site localhost:18000 + SAML Configuration Select the SAMLConfiguration created in Step 1 + Enabled ✓ (checked) + Visible ☐ (unchecked for testing) + Skip hinted login dialog ✓ (checked - recommended) + Skip registration form ✓ (checked - recommended) + Skip email verification ✓ (checked - recommended) + Send to registration first ✓ (checked - recommended) + ========================= =================================================== + +4. Leave all attribute mappings (User ID, Email, Full Name, etc.) blank to use defaults +5. Click "Save" + +**Important**: The Entity ID in SAMLProviderConfig MUST match the Entity ID in SAMLConfiguration. + +Step 3: Set IdP Data +-------------------- + +The SAMLProviderData stores metadata from the Identity Provider (MockSAML), create a record with + +* **Entity ID**: https://saml.example.com/entityid +* **SSO URL**: https://mocksaml.com/api/saml/sso +* **Public Key**: The IdP's signing certificate +* **Expires At**: Set to 1 year from fetch time + + +Step 4: Test SAML Authentication +--------------------------------- + +1. Navigate to: http://localhost:18000/auth/idp_redirect/saml-default +2. You should be redirected to MockSAML.com +3. Complete the authentication on MockSAML - just click "Sign In" with whatever is in the form. +4. You should be redirected back to Open edX +5. If this is a new user, you'll see the registration form +6. After registration, you should be logged in + +Expected Behavior +^^^^^^^^^^^^^^^^^ + +1. Initial redirect to MockSAML (https://mocksaml.com/api/saml/sso) +2. MockSAML displays the login page +3. After authentication, MockSAML POSTs the SAML assertion back to Open edX +4. Open edX validates the assertion and creates/logs in the user +5. User is redirected to the dashboard or registration form (if new user) + +Reference Configuration +----------------------- + +Here's a summary of a working test configuration: + +**SAMLConfiguration** (id=6): + +* Site: localhost:18000 +* Slug: **default** +* Entity ID: https://saml.example.com/entityid +* Enabled: True + +**SAMLProviderConfig** (id=11): + +* Name: Test Localhost +* Slug: default +* Entity ID: https://saml.example.com/entityid +* Metadata Source: https://mocksaml.com/api/saml/metadata +* Backend Name: tpa-saml +* Site: localhost:18000 +* SAML Configuration: → SAMLConfiguration (id=6) +* Enabled: True + +**SAMLProviderData** (id=3): + +* Entity ID: https://saml.example.com/entityid +* SSO URL: https://mocksaml.com/api/saml/sso +* Public Key: (certificate from MockSAML metadata) +* Fetched At: 2026-02-27 18:05:40+00:00 +* Expires At: 2027-02-27 18:05:41+00:00 +* Valid: True + +**MockSAML Configuration**: + +* SP Entity ID: https://saml.example.com/entityid +* ACS URL: http://localhost:18000/auth/complete/tpa-saml/ +* Test User Attributes: email, firstName, lastName, uid diff --git a/common/djangoapps/third_party_auth/management/commands/saml.py b/common/djangoapps/third_party_auth/management/commands/saml.py index afe369c2ade0..6865ebf69987 100644 --- a/common/djangoapps/third_party_auth/management/commands/saml.py +++ b/common/djangoapps/third_party_auth/management/commands/saml.py @@ -6,7 +6,6 @@ import logging from django.core.management.base import BaseCommand, CommandError -from edx_django_utils.monitoring import set_custom_attribute from common.djangoapps.third_party_auth.tasks import fetch_saml_metadata from common.djangoapps.third_party_auth.models import SAMLProviderConfig, SAMLConfiguration @@ -71,31 +70,28 @@ def _handle_run_checks(self): """ Handle the --run-checks option for checking SAMLProviderConfig configuration issues. - This is a report-only command. It identifies potential configuration problems such as: - - Outdated SAMLConfiguration references (provider pointing to old config version) - - Site ID mismatches between SAMLProviderConfig and its SAMLConfiguration - - Slug mismatches (except 'default' slugs) # noqa: E501 - - SAMLProviderConfig objects with null SAMLConfiguration references (informational) - - Includes observability attributes for monitoring. + This is a report-only command that identifies potential configuration problems. """ - # Set custom attributes for monitoring the check operation - # .. custom_attribute_name: saml_management_command.operation - # .. custom_attribute_description: Records current SAML operation ('run_checks'). - set_custom_attribute('saml_management_command.operation', 'run_checks') - metrics = self._check_provider_configurations() self._report_check_summary(metrics) def _check_provider_configurations(self): """ - Check each provider configuration for potential issues. + Check each provider configuration for potential issues: + - Outdated configuration references + - Site ID mismatches + - Missing configurations (no direct config and no default) + - Disabled providers and configurations + Also reports informational data such as slug mismatches. + + See code comments near each log output for possible resolution details. Returns a dictionary of metrics about the found issues. """ outdated_count = 0 site_mismatch_count = 0 slug_mismatch_count = 0 null_config_count = 0 + disabled_config_count = 0 error_count = 0 total_providers = 0 @@ -107,53 +103,74 @@ def _check_provider_configurations(self): for provider_config in provider_configs: total_providers += 1 + + # Check if provider is disabled + provider_disabled = not provider_config.enabled + disabled_status = ", enabled=False" if provider_disabled else "" + provider_info = ( - f"Provider (id={provider_config.id}, name={provider_config.name}, " - f"slug={provider_config.slug}, site_id={provider_config.site_id})" + f"Provider (id={provider_config.id}, " + f"name={provider_config.name}, slug={provider_config.slug}, " + f"site_id={provider_config.site_id}{disabled_status})" ) - if not provider_config.saml_configuration: - self.stdout.write( - f"[INFO] {provider_info} has no SAML configuration because " - "a matching default was not found." - ) - null_config_count += 1 - continue + # Provider disabled status is already included in provider_info format try: + if not provider_config.saml_configuration: + null_config_count, disabled_config_count = self._check_no_config( + provider_config, provider_info, null_config_count, disabled_config_count + ) + continue + + # Check if SAML configuration is disabled + if not provider_config.saml_configuration.enabled: + # Resolution: Enable the SAML configuration in Django admin + # or assign a different configuration + self.stdout.write( + f"[WARNING] {provider_info} " + f"has SAML config (id={provider_config.saml_configuration_id}, enabled=False)." + ) + disabled_config_count += 1 + + # Check configuration currency current_config = SAMLConfiguration.current( provider_config.saml_configuration.site_id, provider_config.saml_configuration.slug ) - # Check for outdated configuration references - if current_config: - if current_config.id != provider_config.saml_configuration_id: - self.stdout.write( - f"[WARNING] {provider_info} " - f"has outdated SAML config (id={provider_config.saml_configuration_id} which " - f"should be updated to the current SAML config (id={current_config.id})." - ) - outdated_count += 1 + if current_config and (current_config.id != provider_config.saml_configuration_id): + # Resolution: Update the provider's saml_configuration_id to the current config ID + self.stdout.write( + f"[WARNING] {provider_info} " + f"has outdated SAML config (id={provider_config.saml_configuration_id}) which " + f"should be updated to the current SAML config (id={current_config.id})." + ) + outdated_count += 1 + # Check site ID match if provider_config.saml_configuration.site_id != provider_config.site_id: config_site_id = provider_config.saml_configuration.site_id - provider_site_id = provider_config.site_id + # Resolution: Create a new SAML configuration for the correct site + # or move the provider to the matching site self.stdout.write( f"[WARNING] {provider_info} " - f"SAML config (id={provider_config.saml_configuration_id}, site_id={config_site_id}) " - "does not match the provider's site_id." + f"SAML config (id={provider_config.saml_configuration_id}, " + f"site_id={config_site_id}) does not match the provider's site_id." ) site_mismatch_count += 1 - saml_configuration_slug = provider_config.saml_configuration.slug - provider_config_slug = provider_config.slug - - if saml_configuration_slug not in (provider_config_slug, 'default'): + # Check slug match + if provider_config.saml_configuration.slug not in (provider_config.slug, 'default'): + config_id = provider_config.saml_configuration_id + saml_configuration_slug = provider_config.saml_configuration.slug + config_disabled_status = ", enabled=False" if not provider_config.saml_configuration.enabled else "" + # Resolution: This is informational only - provider can use + # a different slug configuration self.stdout.write( - f"[WARNING] {provider_info} " - f"SAML config (id={provider_config.saml_configuration_id}, slug='{saml_configuration_slug}') " - "does not match the provider's slug." + f"[INFO] {provider_info} has " + f"SAML config (id={config_id}, slug='{saml_configuration_slug}'{config_disabled_status}) " + "that does not match the provider's slug." ) slug_mismatch_count += 1 @@ -165,41 +182,64 @@ def _check_provider_configurations(self): 'total_providers': {'count': total_providers, 'requires_attention': False}, 'outdated_count': {'count': outdated_count, 'requires_attention': True}, 'site_mismatch_count': {'count': site_mismatch_count, 'requires_attention': True}, - 'slug_mismatch_count': {'count': slug_mismatch_count, 'requires_attention': True}, + 'slug_mismatch_count': {'count': slug_mismatch_count, 'requires_attention': False}, 'null_config_count': {'count': null_config_count, 'requires_attention': False}, + 'disabled_config_count': {'count': disabled_config_count, 'requires_attention': True}, 'error_count': {'count': error_count, 'requires_attention': True}, } - for key, metric_data in metrics.items(): - # .. custom_attribute_name: saml_management_command.{key} - # .. custom_attribute_description: Records metrics from SAML configuration checks. - set_custom_attribute(f'saml_management_command.{key}', metric_data['count']) - return metrics + def _check_no_config(self, provider_config, provider_info, null_config_count, disabled_config_count): + """Helper to check providers with no direct SAML configuration.""" + default_config = SAMLConfiguration.current(provider_config.site_id, 'default') + if not default_config or default_config.id is None: + # Resolution: Create/Link a SAML configuration for this provider + # or create/link a default configuration for the site + self.stdout.write( + f"[WARNING] {provider_info} has no direct SAML configuration and " + "no matching default configuration was found." + ) + null_config_count += 1 + + elif not default_config.enabled: + # Resolution: Enable the provider's linked SAML configuration + # or create/link a specific configuration for this provider + self.stdout.write( + f"[WARNING] {provider_info} has no direct SAML configuration and " + f"the default configuration (id={default_config.id}, enabled=False)." + ) + disabled_config_count += 1 + + return null_config_count, disabled_config_count + def _report_check_summary(self, metrics): """ - Print a summary of the check results and set the total_requiring_attention custom attribute. + Print a summary of the check results. """ total_requiring_attention = sum( metric_data['count'] for metric_data in metrics.values() if metric_data['requires_attention'] ) - # .. custom_attribute_name: saml_management_command.total_requiring_attention - # .. custom_attribute_description: The total number of configuration issues requiring attention. - set_custom_attribute('saml_management_command.total_requiring_attention', total_requiring_attention) - self.stdout.write(self.style.SUCCESS("CHECK SUMMARY:")) self.stdout.write(f" Providers checked: {metrics['total_providers']['count']}") - self.stdout.write(f" Null configs: {metrics['null_config_count']['count']}") + self.stdout.write("") + + # Informational only section + self.stdout.write("Informational only:") + self.stdout.write(f" Slug mismatches: {metrics['slug_mismatch_count']['count']}") + self.stdout.write(f" Missing configs: {metrics['null_config_count']['count']}") + self.stdout.write("") + # Issues requiring attention section if total_requiring_attention > 0: - self.stdout.write("\nIssues requiring attention:") + self.stdout.write("Issues requiring attention:") self.stdout.write(f" Outdated: {metrics['outdated_count']['count']}") self.stdout.write(f" Site mismatches: {metrics['site_mismatch_count']['count']}") - self.stdout.write(f" Slug mismatches: {metrics['slug_mismatch_count']['count']}") + self.stdout.write(f" Disabled configs: {metrics['disabled_config_count']['count']}") self.stdout.write(f" Errors: {metrics['error_count']['count']}") - self.stdout.write(f"\nTotal issues requiring attention: {total_requiring_attention}") + self.stdout.write("") + self.stdout.write(f"Total issues requiring attention: {total_requiring_attention}") else: - self.stdout.write(self.style.SUCCESS("\nNo configuration issues found!")) + self.stdout.write(self.style.SUCCESS("No configuration issues found!")) diff --git a/common/djangoapps/third_party_auth/management/commands/tests/test_saml.py b/common/djangoapps/third_party_auth/management/commands/tests/test_saml.py index 6963d5dcd0d5..d80c9146664b 100644 --- a/common/djangoapps/third_party_auth/management/commands/tests/test_saml.py +++ b/common/djangoapps/third_party_auth/management/commands/tests/test_saml.py @@ -79,6 +79,7 @@ def setUp(self): name='TestShib College', entity_id='https://idp.testshib.org/idp/shibboleth', metadata_source='https://www.testshib.org/metadata/testshib-providers.xml', + saml_configuration=self.saml_config, ) def _setup_test_configs_for_run_checks(self): @@ -337,8 +338,30 @@ def _run_checks_command(self): call_command('saml', '--run-checks', stdout=out) return out.getvalue() - @mock.patch('common.djangoapps.third_party_auth.management.commands.saml.set_custom_attribute') - def test_run_checks_outdated_configs(self, mock_set_custom_attribute): + def test_run_checks_setup_test_data(self): + """ + Test the --run-checks command against initial setup test data. + + This test validates that the base setup data (from setUp) is correctly + identified as having configuration issues. The setup includes a provider + (self.provider_config) with a disabled SAML configuration (self.saml_config), + which is reported as a disabled config issue (not a missing config). + """ + output = self._run_checks_command() + + # The setup data includes a provider with a disabled SAML config + expected_warning = ( + f'[WARNING] Provider (id={self.provider_config.id}, ' + f'name={self.provider_config.name}, ' + f'slug={self.provider_config.slug}, ' + f'site_id={self.provider_config.site_id}) ' + f'has SAML config (id={self.saml_config.id}, enabled=False).' + ) + self.assertIn(expected_warning, output) + self.assertIn('Missing configs: 0', output) # No missing configs from setUp + self.assertIn('Disabled configs: 1', output) # From setUp: provider_config with disabled saml_config + + def test_run_checks_outdated_configs(self): """ Test the --run-checks command identifies outdated configurations. """ @@ -346,31 +369,18 @@ def test_run_checks_outdated_configs(self, mock_set_custom_attribute): output = self._run_checks_command() - self.assertIn('[WARNING]', output) - self.assertIn('test-provider', output) - self.assertIn( - f'id={old_config.id} which should be updated to the current SAML config (id={new_config.id})', - output + expected_warning = ( + f'[WARNING] Provider (id={test_provider_config.id}, name={test_provider_config.name}, ' + f'slug={test_provider_config.slug}, site_id={test_provider_config.site_id}) ' + f'has outdated SAML config (id={old_config.id}) which should be updated to ' + f'the current SAML config (id={new_config.id}).' ) - self.assertIn('CHECK SUMMARY:', output) - self.assertIn('Providers checked: 2', output) + self.assertIn(expected_warning, output) self.assertIn('Outdated: 1', output) + # Total includes: 1 outdated + 2 disabled configs (setUp + test's old_config which is also disabled) + self.assertIn('Total issues requiring attention: 3', output) - # Check key observability calls - expected_calls = [ - mock.call('saml_management_command.operation', 'run_checks'), - mock.call('saml_management_command.total_providers', 2), - mock.call('saml_management_command.outdated_count', 1), - mock.call('saml_management_command.site_mismatch_count', 0), - mock.call('saml_management_command.slug_mismatch_count', 1), - mock.call('saml_management_command.null_config_count', 1), - mock.call('saml_management_command.error_count', 0), - mock.call('saml_management_command.total_requiring_attention', 2), - ] - mock_set_custom_attribute.assert_has_calls(expected_calls, any_order=False) - - @mock.patch('common.djangoapps.third_party_auth.management.commands.saml.set_custom_attribute') - def test_run_checks_site_mismatches(self, mock_set_custom_attribute): + def test_run_checks_site_mismatches(self): """ Test the --run-checks command identifies site ID mismatches. """ @@ -380,7 +390,7 @@ def test_run_checks_site_mismatches(self, mock_set_custom_attribute): entity_id='https://example.com' ) - SAMLProviderConfigFactory.create( + provider = SAMLProviderConfigFactory.create( site=self.site, slug='test-provider', saml_configuration=config @@ -388,25 +398,17 @@ def test_run_checks_site_mismatches(self, mock_set_custom_attribute): output = self._run_checks_command() - self.assertIn('[WARNING]', output) - self.assertIn('test-provider', output) - self.assertIn('does not match the provider\'s site_id', output) - - # Check observability calls - expected_calls = [ - mock.call('saml_management_command.operation', 'run_checks'), - mock.call('saml_management_command.total_providers', 2), - mock.call('saml_management_command.outdated_count', 0), - mock.call('saml_management_command.site_mismatch_count', 1), - mock.call('saml_management_command.slug_mismatch_count', 1), - mock.call('saml_management_command.null_config_count', 1), - mock.call('saml_management_command.error_count', 0), - mock.call('saml_management_command.total_requiring_attention', 2), - ] - mock_set_custom_attribute.assert_has_calls(expected_calls, any_order=False) - - @mock.patch('common.djangoapps.third_party_auth.management.commands.saml.set_custom_attribute') - def test_run_checks_slug_mismatches(self, mock_set_custom_attribute): + expected_warning = ( + f'[WARNING] Provider (id={provider.id}, name={provider.name}, ' + f'slug={provider.slug}, site_id={provider.site_id}) ' + f'SAML config (id={config.id}, site_id={config.site_id}) does not match the provider\'s site_id.' + ) + self.assertIn(expected_warning, output) + self.assertIn('Site mismatches: 1', output) + # Total includes: 1 site mismatch + 1 disabled config (from setUp) + self.assertIn('Total issues requiring attention: 2', output) + + def test_run_checks_slug_mismatches(self): """ Test the --run-checks command identifies slug mismatches. """ @@ -416,7 +418,7 @@ def test_run_checks_slug_mismatches(self, mock_set_custom_attribute): entity_id='https://example.com' ) - SAMLProviderConfigFactory.create( + provider = SAMLProviderConfigFactory.create( site=self.site, slug='provider-slug', saml_configuration=config @@ -424,29 +426,23 @@ def test_run_checks_slug_mismatches(self, mock_set_custom_attribute): output = self._run_checks_command() - self.assertIn('[WARNING]', output) - self.assertIn('provider-slug', output) - self.assertIn('does not match the provider\'s slug', output) - - # Check observability calls - expected_calls = [ - mock.call('saml_management_command.operation', 'run_checks'), - mock.call('saml_management_command.total_providers', 2), - mock.call('saml_management_command.outdated_count', 0), - mock.call('saml_management_command.site_mismatch_count', 0), - mock.call('saml_management_command.slug_mismatch_count', 1), - mock.call('saml_management_command.null_config_count', 1), - mock.call('saml_management_command.error_count', 0), - mock.call('saml_management_command.total_requiring_attention', 1), - ] - mock_set_custom_attribute.assert_has_calls(expected_calls, any_order=False) - - @mock.patch('common.djangoapps.third_party_auth.management.commands.saml.set_custom_attribute') - def test_run_checks_null_configurations(self, mock_set_custom_attribute): + expected_info = ( + f'[INFO] Provider (id={provider.id}, name={provider.name}, ' + f'slug={provider.slug}, site_id={provider.site_id}) ' + f'has SAML config (id={config.id}, slug=\'{config.slug}\') ' + f'that does not match the provider\'s slug.' + ) + self.assertIn(expected_info, output) + self.assertIn('Slug mismatches: 1', output) + + def test_run_checks_null_configurations(self): """ Test the --run-checks command identifies providers with null configurations. + This test verifies that providers with no direct SAML configuration and no + default configuration available are properly reported. """ - SAMLProviderConfigFactory.create( + # Create a provider with no SAML configuration on a site that has no default config + provider = SAMLProviderConfigFactory.create( site=self.site, slug='null-provider', saml_configuration=None @@ -454,19 +450,101 @@ def test_run_checks_null_configurations(self, mock_set_custom_attribute): output = self._run_checks_command() - self.assertIn('[INFO]', output) - self.assertIn('null-provider', output) - self.assertIn('has no SAML configuration because a matching default was not found', output) - - # Check observability calls - expected_calls = [ - mock.call('saml_management_command.operation', 'run_checks'), - mock.call('saml_management_command.total_providers', 2), - mock.call('saml_management_command.outdated_count', 0), - mock.call('saml_management_command.site_mismatch_count', 0), - mock.call('saml_management_command.slug_mismatch_count', 0), - mock.call('saml_management_command.null_config_count', 2), - mock.call('saml_management_command.error_count', 0), - mock.call('saml_management_command.total_requiring_attention', 0), - ] - mock_set_custom_attribute.assert_has_calls(expected_calls, any_order=False) + expected_warning = ( + f'[WARNING] Provider (id={provider.id}, name={provider.name}, ' + f'slug={provider.slug}, site_id={provider.site_id}) ' + f'has no direct SAML configuration and no matching default configuration was found.' + ) + self.assertIn(expected_warning, output) + self.assertIn('Missing configs: 1', output) # 1 from this test (provider with no config and no default) + self.assertIn('Disabled configs: 1', output) # 1 from setUp data + + def test_run_checks_null_config_id(self): + """ + Test the --run-checks command identifies providers with disabled default configurations. + When a provider has no direct SAML configuration and the default config is disabled, + it should be reported as a missing config issue. + """ + # Create a disabled default configuration for this site + disabled_default_config = SAMLConfigurationFactory.create( + site=self.site, + slug='default', + entity_id='https://default.example.com', + enabled=False + ) + + # Create a provider with no direct SAML configuration + # It will fall back to the disabled default config + provider = SAMLProviderConfigFactory.create( + site=self.site, + slug='null-id-provider', + saml_configuration=None + ) + + output = self._run_checks_command() + + expected_warning = ( + f'[WARNING] Provider (id={provider.id}, name={provider.name}, ' + f'slug={provider.slug}, site_id={provider.site_id}) ' + f'has no direct SAML configuration and the default configuration ' + f'(id={disabled_default_config.id}, enabled=False).' + ) + self.assertIn(expected_warning, output) + self.assertIn('Missing configs: 0', output) # No missing configs since default config exists + self.assertIn('Disabled configs: 2', output) # 1 from this test + 1 from setUp data + + def test_run_checks_with_default_config(self): + """ + Test the --run-checks command correctly handles providers with default configurations. + """ + provider = SAMLProviderConfigFactory.create( + site=self.site, + slug='default-config-provider', + saml_configuration=None + ) + + default_config = SAMLConfigurationFactory.create( + site=self.site, + slug='default', + entity_id='https://default.example.com' + ) + + output = self._run_checks_command() + + self.assertIn('Missing configs: 0', output) # This tests provider has valid default config + self.assertIn('Disabled configs: 1', output) # From setUp + + def test_run_checks_disabled_functionality(self): + """ + Test the --run-checks command handles disabled providers and configurations. + """ + disabled_provider = SAMLProviderConfigFactory.create( + site=self.site, + slug='disabled-provider', + enabled=False + ) + + disabled_config = SAMLConfigurationFactory.create( + site=self.site, + slug='disabled-config', + enabled=False + ) + + provider_with_disabled_config = SAMLProviderConfigFactory.create( + site=self.site, + slug='provider-with-disabled-config', + saml_configuration=disabled_config + ) + + output = self._run_checks_command() + + expected_warning = ( + f'[WARNING] Provider (id={provider_with_disabled_config.id}, ' + f'name={provider_with_disabled_config.name}, ' + f'slug={provider_with_disabled_config.slug}, ' + f'site_id={provider_with_disabled_config.site_id}) ' + f'has SAML config (id={disabled_config.id}, enabled=False).' + ) + self.assertIn(expected_warning, output) + self.assertIn('Missing configs: 1', output) # disabled_provider has no config + self.assertIn('Disabled configs: 2', output) # setUp's provider + provider_with_disabled_config diff --git a/common/djangoapps/third_party_auth/migrations/0014_samlproviderconfig_optional_email_checkboxes.py b/common/djangoapps/third_party_auth/migrations/0014_samlproviderconfig_optional_email_checkboxes.py new file mode 100644 index 000000000000..34fcf3c97b58 --- /dev/null +++ b/common/djangoapps/third_party_auth/migrations/0014_samlproviderconfig_optional_email_checkboxes.py @@ -0,0 +1,26 @@ +# Generated migration for adding optional checkbox skip configuration field + +from django.db import migrations, models +import django.utils.translation + + +class Migration(migrations.Migration): + + dependencies = [ + ('third_party_auth', '0013_default_site_id_wrapper_function'), + ] + + operations = [ + migrations.AddField( + model_name='samlproviderconfig', + name='skip_registration_optional_checkboxes', + field=models.BooleanField( + default=False, + help_text=django.utils.translation.gettext_lazy( + "If enabled, optional checkboxes (marketing emails opt-in, etc.) will not be rendered " + "on the registration form for users registering via this provider. When these checkboxes " + "are skipped, their values are inferred as False (opted out)." + ), + ), + ), + ] diff --git a/common/djangoapps/third_party_auth/models.py b/common/djangoapps/third_party_auth/models.py index 6d244d96eddd..6ae816674b07 100644 --- a/common/djangoapps/third_party_auth/models.py +++ b/common/djangoapps/third_party_auth/models.py @@ -745,6 +745,14 @@ class SAMLProviderConfig(ProviderConfig): "immediately after authenticating with the third party instead of the login page." ), ) + skip_registration_optional_checkboxes = models.BooleanField( + default=False, + help_text=_( + "If enabled, optional checkboxes (marketing emails opt-in, etc.) will not be rendered " + "on the registration form for users registering via this provider. When these checkboxes " + "are skipped, their values are inferred as False (opted out)." + ), + ) other_settings = models.TextField( verbose_name="Advanced settings", blank=True, help_text=( @@ -803,13 +811,27 @@ def get_url_params(self): def is_active_for_pipeline(self, pipeline): """ Is this provider being used for the specified pipeline? """ - return self.backend_name == pipeline['backend'] and self.slug == pipeline['kwargs']['response']['idp_name'] + try: + return self.backend_name == pipeline['backend'] and self.slug == pipeline['kwargs']['response']['idp_name'] + except KeyError: + return False def match_social_auth(self, social_auth): """ Is this provider being used for this UserSocialAuth entry? """ prefix = self.slug + ":" return self.backend_name == social_auth.provider and social_auth.uid.startswith(prefix) + def get_remote_id_from_field_name(self, social_auth, field_name): + """ Given a UserSocialAuth object, return the user remote ID against the field name provided. """ + if not self.match_social_auth(social_auth): + raise ValueError( + f"UserSocialAuth record does not match given provider {self.provider_id}" + ) + field_value = social_auth.extra_data.get(field_name, None) + if field_value and isinstance(field_value, list): + return field_value[0] + return field_value + def get_remote_id_from_social_auth(self, social_auth): """ Given a UserSocialAuth object, return the remote ID used by this provider. """ assert self.match_social_auth(social_auth) diff --git a/common/djangoapps/third_party_auth/pipeline.py b/common/djangoapps/third_party_auth/pipeline.py index 496cfce93c1f..d620d27f5357 100644 --- a/common/djangoapps/third_party_auth/pipeline.py +++ b/common/djangoapps/third_party_auth/pipeline.py @@ -62,6 +62,7 @@ def B(*args, **kwargs): import hashlib import hmac import json +import urllib.parse from collections import OrderedDict from logging import getLogger from smtplib import SMTPException @@ -101,6 +102,10 @@ def B(*args, **kwargs): is_saml_provider, user_exists, ) +from common.djangoapps.third_party_auth.toggles import ( + is_saml_provider_site_fallback_enabled, + is_tpa_next_url_on_dispatch_enabled, +) from common.djangoapps.track import segment from common.djangoapps.util.json_request import JsonResponse @@ -358,7 +363,11 @@ def get_complete_url(backend_name): ValueError: if no provider is enabled with the given backend_name. """ if not any(provider.Registry.get_enabled_by_backend_name(backend_name)): - raise ValueError('Provider with backend %s not enabled' % backend_name) + # When the SAML site-fallback flag is on, the provider may not be visible to the + # site-filtered registry even though SAML auth already completed via a + # site-independent lookup. Allow get_complete_url to proceed in that case. + if not (is_saml_provider_site_fallback_enabled() and backend_name == 'tpa-saml'): + raise ValueError('Provider with backend %s not enabled' % backend_name) return _get_url('social:complete', backend_name) @@ -576,13 +585,23 @@ def ensure_user_information(strategy, auth_entry, backend=None, user=None, socia # It is important that we always execute the entire pipeline. Even if # behavior appears correct without executing a step, it means important # invariants have been violated and future misbehavior is likely. + def _build_redirect_url(base_url): + """Append ?next=… to the redirect URL if the session carries a destination.""" + if not is_tpa_next_url_on_dispatch_enabled(): + return base_url + next_url = strategy.session_get('next') + if next_url and isinstance(next_url, str): + separator = '&' if '?' in base_url else '?' + base_url = f'{base_url}{separator}next={urllib.parse.quote(next_url)}' + return base_url + def dispatch_to_login(): """Redirects to the login page.""" - return redirect(AUTH_DISPATCH_URLS[AUTH_ENTRY_LOGIN]) + return redirect(_build_redirect_url(AUTH_DISPATCH_URLS[AUTH_ENTRY_LOGIN])) def dispatch_to_register(): """Redirects to the registration page.""" - return redirect(AUTH_DISPATCH_URLS[AUTH_ENTRY_REGISTER]) + return redirect(_build_redirect_url(AUTH_DISPATCH_URLS[AUTH_ENTRY_REGISTER])) def should_force_account_creation(): """ For some third party providers, we auto-create user accounts """ @@ -603,13 +622,35 @@ def is_provider_saml(): if not user: # Use only email for user existence check in case of saml provider - if is_provider_saml(): + _is_saml = is_provider_saml() + _provider_obj = provider.Registry.get_from_pipeline({'backend': current_partial.backend, 'kwargs': kwargs}) + logger.info( + '[THIRD_PARTY_AUTH] ensure_user_information: auth_entry=%s backend=%s is_provider_saml=%s ' + 'current_provider=%s skip_email_verification=%s send_to_registration_first=%s ' + 'email=%s kwargs_response_keys=%s', + auth_entry, + current_partial.backend, + _is_saml, + _provider_obj.provider_id if _provider_obj else None, + _provider_obj.skip_email_verification if _provider_obj else None, + _provider_obj.send_to_registration_first if _provider_obj else None, + details.get('email') if details else None, + list((kwargs.get('response') or {}).keys()), + ) + if _is_saml: user_details = {'email': details.get('email')} if details else None else: user_details = details - if user_exists(user_details or {}): + _user_exists = user_exists(user_details or {}) + logger.info( + '[THIRD_PARTY_AUTH] ensure_user_information: user_exists=%s user_details_email=%s', + _user_exists, + (user_details or {}).get('email'), + ) + if _user_exists: # User has not already authenticated and the details sent over from # identity provider belong to an existing user. + logger.info('[THIRD_PARTY_AUTH] ensure_user_information: dispatching to login (user exists)') return dispatch_to_login() if is_api(auth_entry): @@ -617,8 +658,14 @@ def is_provider_saml(): elif auth_entry == AUTH_ENTRY_LOGIN: # User has authenticated with the third party provider but we don't know which edX # account corresponds to them yet, if any. - if should_force_account_creation(): + _force = should_force_account_creation() + logger.info( + '[THIRD_PARTY_AUTH] ensure_user_information: AUTH_ENTRY_LOGIN should_force_account_creation=%s', + _force, + ) + if _force: return dispatch_to_register() + logger.info('[THIRD_PARTY_AUTH] ensure_user_information: dispatching to login (no force create)') return dispatch_to_login() elif auth_entry == AUTH_ENTRY_REGISTER: # User has authenticated with the third party provider and now wants to finish @@ -1009,7 +1056,7 @@ def get_username(strategy, details, backend, user=None, *args, **kwargs): # lin else: slug_func = lambda val: val - if is_auto_generated_username_enabled(): + if is_auto_generated_username_enabled() and details.get('username') is None: username = get_auto_generated_username(details) else: if email_as_username and details.get('email'): diff --git a/common/djangoapps/third_party_auth/provider.py b/common/djangoapps/third_party_auth/provider.py index 5eaf0f1888de..3cb9f8642158 100644 --- a/common/djangoapps/third_party_auth/provider.py +++ b/common/djangoapps/third_party_auth/provider.py @@ -16,6 +16,7 @@ SAMLConfiguration, SAMLProviderConfig ) +from common.djangoapps.third_party_auth.toggles import is_saml_provider_site_fallback_enabled class Registry: @@ -97,6 +98,19 @@ def get_from_pipeline(cls, running_pipeline): if enabled.is_active_for_pipeline(running_pipeline): return enabled + # Fallback for SAML: SAMLAuthBackend.get_idp() uses SAMLProviderConfig.current() + # which has no site check. If the provider's site_id doesn't match the current + # site (or SAMLConfiguration isn't enabled for the current site), _enabled_providers() + # won't yield it — but the SAML handshake already completed. Look up the provider + # directly by idp_name so that pipeline steps like should_force_account_creation() + # can still read provider flags. + if is_saml_provider_site_fallback_enabled() and running_pipeline.get('backend') == 'tpa-saml': + try: + idp_name = running_pipeline['kwargs']['response']['idp_name'] + return SAMLProviderConfig.current(idp_name) + except (KeyError, SAMLProviderConfig.DoesNotExist): + pass + @classmethod def get_enabled_by_backend_name(cls, backend_name): """Generator returning all enabled providers that use the specified diff --git a/common/djangoapps/third_party_auth/saml.py b/common/djangoapps/third_party_auth/saml.py index 8e78f9e36fc9..3f1cdf30a9e1 100644 --- a/common/djangoapps/third_party_auth/saml.py +++ b/common/djangoapps/third_party_auth/saml.py @@ -4,6 +4,7 @@ import logging +from urllib.parse import unquote from copy import deepcopy import requests @@ -17,6 +18,7 @@ from social_core.exceptions import AuthForbidden, AuthMissingParameter from openedx.core.djangoapps.theming.helpers import get_current_request +from openedx.core.djangoapps.user_authn.utils import is_safe_login_or_logout_redirect from common.djangoapps.third_party_auth.exceptions import IncorrectConfigurationException STANDARD_SAML_PROVIDER_KEY = 'standard_saml_provider' @@ -89,12 +91,71 @@ def auth_complete(self, *args, **kwargs): """ Handle exceptions that happen during SAML authentication """ + # For IdP-initiated flows (where the user doesn't first hit /auth/login/...), + # allow callers to provide a post-auth redirect by packing it into RelayState. + # Store it in the session so the rest of the pipeline behaves consistently. + try: + request = get_current_request() + # Allow RelayState to carry both IdP slug and a post-auth destination. + # Format: "|", where is typically a relative LMS path. + self._maybe_set_next_url_from_relay_state(request) + except Exception: # pylint: disable=broad-exception-caught # pragma: no cover + # Never fail auth due to redirect bookkeeping. + pass + try: return super().auth_complete(*args, **kwargs) # We are seeing errors of MultiValueDictKeyError looking for the parameter 'RelayState'. # We would like to have a more specific error to handle for observability purposes. except MultiValueDictKeyError as e: - raise AuthMissingParameter(self.name, e.args[0]) from e + raise AuthMissingParameter(self.name, e.args[0] if e.args else '') from e + + @staticmethod + def _maybe_set_next_url_from_relay_state(request): + """Optionally extract a safe `next` from RelayState and rewrite RelayState to the IdP slug. + + This is specifically to support IdP-initiated flows where Auth0 (and some IdPs) can only + reliably influence the SAML POST via RelayState. + """ + if request is None or not hasattr(request, 'POST'): + return + if not hasattr(request, 'session'): + return + + relay_state = None + try: + relay_state = request.POST.get('RelayState') + except Exception: # pylint: disable=broad-exception-caught # pragma: no cover + relay_state = None + + if not relay_state or '|' not in str(relay_state): + return + + slug_part, next_part = str(relay_state).split('|', 1) + slug_part = slug_part.strip() + next_part = next_part.strip() + if not slug_part or not next_part: + return + + # URL-decode next (Auth0 or callers may URL-encode it). + next_decoded = unquote(next_part) + + # Only store next if it's safe per existing Open edX redirect policy. + if is_safe_login_or_logout_redirect( + redirect_to=next_decoded, + request_host=request.get_host(), + dot_client_id=(request.GET.get('client_id') if hasattr(request, 'GET') else None), + require_https=request.is_secure(), + ): + request.session['next'] = next_decoded + else: + # RelayState included an unsafe destination; clear any stale 'next' value + request.session.pop('next', None) + + # Always rewrite RelayState to just the IdP slug so the SAML backend can locate the provider. + post_copy = request.POST.copy() + post_copy['RelayState'] = slug_part + request._post = post_copy # pylint: disable=protected-access def get_user_id(self, details, response): """ diff --git a/common/djangoapps/third_party_auth/tests/test_pipeline_integration.py b/common/djangoapps/third_party_auth/tests/test_pipeline_integration.py index 0d0a2cf7241d..13c6749b0851 100644 --- a/common/djangoapps/third_party_auth/tests/test_pipeline_integration.py +++ b/common/djangoapps/third_party_auth/tests/test_pipeline_integration.py @@ -379,6 +379,92 @@ def test_redirect_for_saml_based_on_email_only(self, email, expected_redirect_ur assert response.url == expected_redirect_url +@ddt.ddt +class EnsureUserInformationNextUrlTestCase(test.TestCase): + """Tests that ensure_user_information forwards session['next'] as a query parameter.""" + + def _call_ensure_user_information(self, session_next, auth_entry=pipeline.AUTH_ENTRY_LOGIN, + send_to_registration_first=True): + """Helper to call ensure_user_information with a controlled session_get('next') value.""" + mock_provider = mock.MagicMock( + send_to_registration_first=send_to_registration_first, + skip_email_verification=False, + ) + with mock.patch( + 'common.djangoapps.third_party_auth.pipeline.provider.Registry.get_from_pipeline' + ) as get_from_pipeline: + get_from_pipeline.return_value = mock_provider + with mock.patch('social_core.pipeline.partial.partial_prepare') as partial_prepare: + partial_prepare.return_value = mock.MagicMock(token='') + strategy = mock.MagicMock() + strategy.session_get.side_effect = lambda key, *args: ( + session_next if key == 'next' else mock.DEFAULT + ) + response = pipeline.ensure_user_information( + strategy=strategy, + backend=None, + auth_entry=auth_entry, + pipeline_index=0, + ) + return response + + @mock.patch( + 'common.djangoapps.third_party_auth.pipeline.is_tpa_next_url_on_dispatch_enabled', + return_value=True, + ) + @ddt.data( + # (session_next, send_to_registration_first, expected_url) + ('/courses/my-course', True, '/register?next=/courses/my-course'), + ('/courses/my-course', False, '/login?next=/courses/my-course'), + ('/dashboard', True, '/register?next=/dashboard'), + ) + @ddt.unpack + def test_next_url_forwarded_to_redirect(self, session_next, send_to_registration_first, expected_url, _flag_mock): + """When session contains a 'next' URL, it should be appended as a query parameter.""" + response = self._call_ensure_user_information( + session_next=session_next, + send_to_registration_first=send_to_registration_first, + ) + assert response.status_code == 302 + assert response.url == expected_url + + @mock.patch( + 'common.djangoapps.third_party_auth.pipeline.is_tpa_next_url_on_dispatch_enabled', + return_value=True, + ) + @ddt.data(None, '') + def test_no_next_url_gives_bare_redirect(self, session_next, _flag_mock): + """When session has no 'next' URL, the redirect should be bare /register.""" + response = self._call_ensure_user_information(session_next=session_next) + assert response.status_code == 302 + assert response.url == '/register' + + @mock.patch( + 'common.djangoapps.third_party_auth.pipeline.is_tpa_next_url_on_dispatch_enabled', + return_value=True, + ) + def test_next_url_with_special_characters_is_encoded(self, _flag_mock): + """Special characters in the next URL should be percent-encoded.""" + response = self._call_ensure_user_information( + session_next='/courses/my course?foo=bar&baz=1', + ) + assert response.status_code == 302 + assert response.url.startswith('/register?next=') + # The space and & should be encoded + assert '%20' in response.url or '+' in response.url + assert 'foo%3Dbar' in response.url or 'foo=bar' in response.url + + @mock.patch( + 'common.djangoapps.third_party_auth.pipeline.is_tpa_next_url_on_dispatch_enabled', + return_value=False, + ) + def test_flag_disabled_gives_bare_redirect(self, _flag_mock): + """When the waffle flag is disabled, the redirect should be bare even with session['next'].""" + response = self._call_ensure_user_information(session_next='/courses/my-course') + assert response.status_code == 302 + assert response.url == '/register' + + class UserDetailsForceSyncTestCase(TestCase): """Tests to ensure learner profile data is properly synced if the provider requires it.""" diff --git a/common/djangoapps/third_party_auth/tests/test_saml.py b/common/djangoapps/third_party_auth/tests/test_saml.py index 6b966a3e6ea4..3a34a4dabd26 100644 --- a/common/djangoapps/third_party_auth/tests/test_saml.py +++ b/common/djangoapps/third_party_auth/tests/test_saml.py @@ -5,7 +5,9 @@ from unittest import mock +from django.test import RequestFactory from django.utils.datastructures import MultiValueDictKeyError +from django.contrib.sessions.middleware import SessionMiddleware from social_core.exceptions import AuthMissingParameter from common.djangoapps.third_party_auth.saml import EdXSAMLIdentityProvider, get_saml_idp_class, SAMLAuthBackend @@ -40,6 +42,14 @@ def test_get_user_details(self): class TestSAMLAuthBackend(SAMLTestCase): """ Tests for the SAML backend. """ + @staticmethod + def _add_session(request): + """Attach a Django session to a RequestFactory request.""" + middleware = SessionMiddleware(lambda req: None) + middleware.process_request(request) + request.session.save() + return request + @mock.patch('common.djangoapps.third_party_auth.saml.SAMLAuth.auth_complete') def test_saml_auth_complete(self, super_auth_complete): super_auth_complete.side_effect = MultiValueDictKeyError('RelayState') @@ -48,3 +58,49 @@ def test_saml_auth_complete(self, super_auth_complete): backend.auth_complete() assert cm.exception.parameter == 'RelayState' + + @mock.patch('common.djangoapps.third_party_auth.saml.get_current_request') + @mock.patch('common.djangoapps.third_party_auth.saml.SAMLAuth.auth_complete') + def test_relaystate_splits_and_sets_next_when_safe(self, super_auth_complete, get_current_request_mock): + """RelayState may include both the IdP slug and a safe `next` destination.""" + rf = RequestFactory() + request = rf.post( + '/auth/complete/tpa-saml/', + data={ + 'SAMLResponse': 'ignored', + 'RelayState': 'example-idp|/courses/course-v1:edX+DemoX+Demo_Course/course/', + }, + HTTP_HOST=self.hostname, + ) + self._add_session(request) + get_current_request_mock.return_value = request + + super_auth_complete.return_value = 'ok' + backend = SAMLAuthBackend() + assert backend.auth_complete() == 'ok' + + assert request.POST.get('RelayState') == 'example-idp' + assert request.session.get('next') == '/courses/course-v1:edX+DemoX+Demo_Course/course/' + + @mock.patch('common.djangoapps.third_party_auth.saml.get_current_request') + @mock.patch('common.djangoapps.third_party_auth.saml.SAMLAuth.auth_complete') + def test_relaystate_drops_unsafe_next(self, super_auth_complete, get_current_request_mock): + """If RelayState contains an unsafe `next`, it is ignored but the slug is preserved.""" + rf = RequestFactory() + request = rf.post( + '/auth/complete/tpa-saml/', + data={ + 'SAMLResponse': 'ignored', + 'RelayState': 'example-idp|https%3A%2F%2Fevil.example.com%2Fpwn', + }, + HTTP_HOST=self.hostname, + ) + self._add_session(request) + get_current_request_mock.return_value = request + + super_auth_complete.return_value = 'ok' + backend = SAMLAuthBackend() + assert backend.auth_complete() == 'ok' + + assert request.POST.get('RelayState') == 'example-idp' + assert request.session.get('next') is None diff --git a/common/djangoapps/third_party_auth/toggles.py b/common/djangoapps/third_party_auth/toggles.py index d8f77f0b1cf2..51ba5fd0cdb5 100644 --- a/common/djangoapps/third_party_auth/toggles.py +++ b/common/djangoapps/third_party_auth/toggles.py @@ -43,3 +43,55 @@ def is_apple_user_migration_enabled(): Returns a boolean if Apple users migration is in process. """ return APPLE_USER_MIGRATION_FLAG.is_enabled() + + +# .. toggle_name: third_party_auth.tpa_next_url_on_dispatch +# .. toggle_implementation: WaffleFlag +# .. toggle_default: False +# .. toggle_description: When enabled, the third-party auth pipeline will forward +# session['next'] as a ?next= query parameter when redirecting to the login or +# registration page. This ensures the post-auth destination is preserved for new +# users who must complete registration before being redirected. +# .. toggle_use_cases: temporary +# .. toggle_creation_date: 2026-02-13 +# .. toggle_target_removal_date: 2026-06-01 +# .. toggle_warning: None. +TPA_NEXT_URL_ON_DISPATCH_FLAG = WaffleFlag(f'{THIRD_PARTY_AUTH_NAMESPACE}.tpa_next_url_on_dispatch', __name__) + + +def is_tpa_next_url_on_dispatch_enabled(): + """ + Returns True if the pipeline should forward session['next'] as a query parameter + when dispatching to login/register pages. + """ + return TPA_NEXT_URL_ON_DISPATCH_FLAG.is_enabled() + + +# .. toggle_name: third_party_auth.saml_provider_site_fallback +# .. toggle_implementation: WaffleFlag +# .. toggle_default: False +# .. toggle_description: When enabled, Registry.get_from_pipeline() will fall back to a +# site-independent SAMLProviderConfig lookup when the site-filtered registry returns no +# match for a running SAML pipeline. This handles cases where the SAMLProviderConfig or +# SAMLConfiguration is associated with a different Django site than the one currently +# serving the request, while SAML auth itself already completed (SAMLAuthBackend.get_idp() +# has no site check). Without this flag, pipeline steps such as should_force_account_creation() +# cannot read provider flags (e.g. send_to_registration_first), causing new users to land on +# the login page instead of registration. +# .. toggle_use_cases: temporary +# .. toggle_creation_date: 2026-02-19 +# .. toggle_target_removal_date: 2026-06-01 +# .. toggle_warning: The underlying site configuration mismatch should still be fixed in Django +# admin (SAMLConfiguration and SAMLProviderConfig must reference the correct site). This flag +# is a temporary workaround until that is resolved. +SAML_PROVIDER_SITE_FALLBACK_FLAG = WaffleFlag( + f'{THIRD_PARTY_AUTH_NAMESPACE}.saml_provider_site_fallback', __name__ +) + + +def is_saml_provider_site_fallback_enabled(): + """ + Returns True if get_from_pipeline() should fall back to a site-independent + SAMLProviderConfig lookup when the site-filtered registry finds no match. + """ + return SAML_PROVIDER_SITE_FALLBACK_FLAG.is_enabled() diff --git a/common/static/data/geoip/GeoLite2-Country.mmdb b/common/static/data/geoip/GeoLite2-Country.mmdb index 5a7ccaaf8748..f51bdd7c336e 100644 Binary files a/common/static/data/geoip/GeoLite2-Country.mmdb and b/common/static/data/geoip/GeoLite2-Country.mmdb differ diff --git a/common/static/js/vendor/pdfjs/viewer.js b/common/static/js/vendor/pdfjs/viewer.js index bfa90d05a782..e4e4009d3ce0 100644 --- a/common/static/js/vendor/pdfjs/viewer.js +++ b/common/static/js/vendor/pdfjs/viewer.js @@ -27,7 +27,7 @@ 'use strict'; -var DEFAULT_URL = 'compressed.tracemonkey-pldi-09.pdf'; +var DEFAULT_URL = ''; var DEFAULT_SCALE_DELTA = 1.1; var MIN_SCALE = 0.25; var MAX_SCALE = 10.0; diff --git a/lms/djangoapps/branding/tests/test_views.py b/lms/djangoapps/branding/tests/test_views.py index 36ebcd73509e..10c27192d11a 100644 --- a/lms/djangoapps/branding/tests/test_views.py +++ b/lms/djangoapps/branding/tests/test_views.py @@ -269,6 +269,20 @@ def test_index_does_not_redirect_without_site_override(self): response = self.client.get(reverse("root")) assert response.status_code == 200 + @override_settings(ENABLE_MKTG_SITE=True) + @override_settings(MKTG_URLS={'ROOT': 'https://foo.bar/'}) + @override_settings(LMS_ROOT_URL='https://foo.bar/') + def test_index_wont_redirect_to_marketing_root_if_it_matches_lms_root(self): + response = self.client.get(reverse("root")) + assert response.status_code == 200 + + @override_settings(ENABLE_MKTG_SITE=True) + @override_settings(MKTG_URLS={'ROOT': 'https://home.foo.bar/'}) + @override_settings(LMS_ROOT_URL='https://foo.bar/') + def test_index_will_redirect_to_new_root_if_mktg_site_is_enabled(self): + response = self.client.get(reverse("root")) + assert response.status_code == 302 + def test_index_redirects_to_marketing_site_with_site_override(self): """ Test index view redirects if MKTG_URLS['ROOT'] is set in SiteConfiguration """ self.use_site(self.site_other) diff --git a/lms/djangoapps/branding/views.py b/lms/djangoapps/branding/views.py index 711adb85afec..33c5813f16ff 100644 --- a/lms/djangoapps/branding/views.py +++ b/lms/djangoapps/branding/views.py @@ -42,7 +42,7 @@ def index(request): # page to make it easier to browse for courses (and register) if configuration_helpers.get_value( 'ALWAYS_REDIRECT_HOMEPAGE_TO_DASHBOARD_FOR_AUTHENTICATED_USER', - settings.FEATURES.get('ALWAYS_REDIRECT_HOMEPAGE_TO_DASHBOARD_FOR_AUTHENTICATED_USER', True)): + getattr(settings, 'ALWAYS_REDIRECT_HOMEPAGE_TO_DASHBOARD_FOR_AUTHENTICATED_USER', True)): return redirect('dashboard') if use_catalog_mfe(): @@ -50,7 +50,7 @@ def index(request): enable_mktg_site = configuration_helpers.get_value( 'ENABLE_MKTG_SITE', - settings.FEATURES.get('ENABLE_MKTG_SITE', False) + getattr(settings, 'ENABLE_MKTG_SITE', False) ) if enable_mktg_site: @@ -58,7 +58,9 @@ def index(request): 'MKTG_URLS', settings.MKTG_URLS ) - return redirect(marketing_urls.get('ROOT')) + root_url = marketing_urls.get("ROOT") + if root_url != getattr(settings, "LMS_ROOT_URL", None): + return redirect(root_url) domain = request.headers.get('Host') diff --git a/lms/djangoapps/bulk_email/signals.py b/lms/djangoapps/bulk_email/signals.py index 7402dca75482..da18a459aeaa 100644 --- a/lms/djangoapps/bulk_email/signals.py +++ b/lms/djangoapps/bulk_email/signals.py @@ -7,6 +7,7 @@ from eventtracking import tracker from common.djangoapps.student.models import CourseEnrollment +from openedx.core.djangoapps.content.course_overviews.models import CourseOverview from openedx.core.djangoapps.user_api.accounts.signals import USER_RETIRE_MAILINGS from edx_ace.signals import ACE_MESSAGE_SENT @@ -27,7 +28,14 @@ def force_optout_all(sender, **kwargs): # lint-amnesty, pylint: disable=unused- raise TypeError('Expected a User type, but received None.') for enrollment in CourseEnrollment.objects.filter(user=user): - Optout.objects.get_or_create(user=user, course_id=enrollment.course.id) + try: + Optout.objects.get_or_create(user=user, course_id=enrollment.course.id) + except CourseOverview.DoesNotExist: + log.warning( + f"CourseOverview not found for enrollment {enrollment.id} (user: {user.id}), " + f"skipping optout creation. This may mean the course was deleted." + ) + continue @receiver(ACE_MESSAGE_SENT) diff --git a/lms/djangoapps/bulk_email/tests/test_signals.py b/lms/djangoapps/bulk_email/tests/test_signals.py index 1a3715284b12..01ad9312da4c 100644 --- a/lms/djangoapps/bulk_email/tests/test_signals.py +++ b/lms/djangoapps/bulk_email/tests/test_signals.py @@ -10,9 +10,11 @@ from django.core.management import call_command from django.urls import reverse +from common.djangoapps.student.models import CourseEnrollment from common.djangoapps.student.tests.factories import AdminFactory, CourseEnrollmentFactory, UserFactory from lms.djangoapps.bulk_email.models import BulkEmailFlag, Optout from lms.djangoapps.bulk_email.signals import force_optout_all +from opaque_keys.edx.keys import CourseKey from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase # lint-amnesty, pylint: disable=wrong-import-order from xmodule.modulestore.tests.factories import CourseFactory # lint-amnesty, pylint: disable=wrong-import-order @@ -85,3 +87,41 @@ def test_optout_course(self): assert len(mail.outbox) == 1 assert len(mail.outbox[0].to) == 1 assert mail.outbox[0].to[0] == self.instructor.email + + @patch('lms.djangoapps.bulk_email.signals.log.warning') + def test_optout_handles_missing_course_overview(self, mock_log_warning): + """ + Test that force_optout_all gracefully handles CourseEnrollments + with missing CourseOverview records + """ + # Create a course key for a course that doesn't exist in CourseOverview + nonexistent_course_key = CourseKey.from_string('course-v1:TestX+Missing+2023') + + # Create an enrollment with a course_id that doesn't have a CourseOverview + CourseEnrollment.objects.create( + user=self.student, + course_id=nonexistent_course_key, + mode='honor' + ) + + # Verify the orphaned enrollment exists + assert CourseEnrollment.objects.filter( + user=self.student, + course_id=nonexistent_course_key + ).exists() + + force_optout_all(sender=self.__class__, user=self.student) + + # Verify that a warning was logged for the missing CourseOverview + mock_log_warning.assert_called() + call_args = mock_log_warning.call_args[0][0] + assert "CourseOverview not found for enrollment" in call_args + assert f"user: {self.student.id}" in call_args + assert "skipping optout creation" in call_args + + # Verify that optouts were created for valid courses only + valid_course_optouts = Optout.objects.filter(user=self.student, course_id=self.course.id) + missing_course_optouts = Optout.objects.filter(user=self.student, course_id=nonexistent_course_key) + + assert valid_course_optouts.count() == 1 + assert missing_course_optouts.count() == 0 diff --git a/lms/djangoapps/course_home_api/outline/serializers.py b/lms/djangoapps/course_home_api/outline/serializers.py index cfa518138a95..f66012362327 100644 --- a/lms/djangoapps/course_home_api/outline/serializers.py +++ b/lms/djangoapps/course_home_api/outline/serializers.py @@ -62,6 +62,7 @@ def get_blocks(self, block): # pylint: disable=missing-function-docstring 'type': block_type, 'has_scheduled_content': block.get('has_scheduled_content'), 'hide_from_toc': block.get('hide_from_toc'), + 'is_preview': block.get('is_preview', False), }, } if 'special_exam_info' in self.context.get('extra_fields', []) and block.get('special_exam_info'): diff --git a/lms/djangoapps/course_home_api/outline/tests/test_view.py b/lms/djangoapps/course_home_api/outline/tests/test_view.py index 74e22e5fcc4b..6de5db83f94c 100644 --- a/lms/djangoapps/course_home_api/outline/tests/test_view.py +++ b/lms/djangoapps/course_home_api/outline/tests/test_view.py @@ -13,6 +13,7 @@ from django.test import override_settings from django.urls import reverse from edx_toggles.toggles.testutils import override_waffle_flag +from opaque_keys.edx.keys import UsageKey from cms.djangoapps.contentstore.outlines import update_outline_from_modulestore from common.djangoapps.course_modes.models import CourseMode @@ -43,6 +44,7 @@ BlockFactory, CourseFactory ) +from xmodule.partitions.partitions import ENROLLMENT_TRACK_PARTITION_ID @ddt.ddt @@ -484,6 +486,89 @@ def test_course_progress_analytics_disabled(self, mock_task): self.client.get(self.url) mock_task.assert_not_called() + # Tests for verified content preview functionality + # These tests cover the feature that allows audit learners to preview + # the structure of verified-only content without access to the content itself + + @patch('lms.djangoapps.course_home_api.outline.views.learner_can_preview_verified_content') + def test_verified_content_preview_disabled_integration(self, mock_preview_function): + """Test that when verified preview is disabled, no preview markers are added.""" + # Given a course with some Verified only sequences + with self.store.bulk_operations(self.course.id): + chapter = BlockFactory.create(category='chapter', parent_location=self.course.location) + sequential = BlockFactory.create( + category='sequential', + parent_location=chapter.location, + display_name='Verified Sequential', + group_access={ENROLLMENT_TRACK_PARTITION_ID: [2]} # restrict to verified only + ) + update_outline_from_modulestore(self.course.id) + + # ... where the preview feature is disabled + mock_preview_function.return_value = False + + # When I access them as an audit user + CourseEnrollment.enroll(self.user, self.course.id, CourseMode.AUDIT) + response = self.client.get(self.url) + + # Then I get a valid response back + assert response.status_code == 200 + + # ... with course_blocks populated + course_blocks = response.data['course_blocks']["blocks"] + + # ... but with verified content omitted + assert str(sequential.location) not in course_blocks + + # ... and no block has preview set to true + for block in course_blocks: + assert course_blocks[block].get('is_preview') is not True + + @patch('lms.djangoapps.course_home_api.outline.views.learner_can_preview_verified_content') + @patch('lms.djangoapps.course_home_api.outline.views.get_user_course_outline') + def test_verified_content_preview_enabled_marks_previewable_content(self, mock_outline, mock_preview_enabled): + """Test that when verified preview is enabled, previewable sequences and chapters are marked.""" + # Given a course with some Verified only sequences and some regular sequences + with self.store.bulk_operations(self.course.id): + chapter = BlockFactory.create(category='chapter', parent_location=self.course.location) + verified_sequential = BlockFactory.create( + category='sequential', + parent_location=chapter.location, + display_name='Verified Sequential', + ) + regular_sequential = BlockFactory.create( + category='sequential', + parent_location=chapter.location, + display_name='Regular Sequential' + ) + update_outline_from_modulestore(self.course.id) + + # ... with an outline that correctly identifies previewable sequences + mock_course_outline = Mock() + mock_course_outline.sections = {Mock(usage_key=chapter.location)} + mock_course_outline.sequences = {verified_sequential.location, regular_sequential.location} + mock_course_outline.previewable_sequences = {verified_sequential.location} + mock_outline.return_value = mock_course_outline + + # When I access them as an audit user with preview enabled + CourseEnrollment.enroll(self.user, self.course.id, CourseMode.AUDIT) + mock_preview_enabled.return_value = True + + # Then I get a valid response back + response = self.client.get(self.url) + assert response.status_code == 200 + + # ... with course_blocks populated + course_blocks = response.data['course_blocks']["blocks"] + + for block in course_blocks: + # ... and the verified only content is marked as preview only + if UsageKey.from_string(block) in mock_course_outline.previewable_sequences: + assert course_blocks[block].get('is_preview') is True + # ... and the regular content is not marked as preview + else: + assert course_blocks[block].get('is_preview') is False + @ddt.ddt class SidebarBlocksTestViews(BaseCourseHomeTests): diff --git a/lms/djangoapps/course_home_api/outline/views.py b/lms/djangoapps/course_home_api/outline/views.py index 7c5307cba764..78d5767ffeed 100644 --- a/lms/djangoapps/course_home_api/outline/views.py +++ b/lms/djangoapps/course_home_api/outline/views.py @@ -36,7 +36,10 @@ ) from lms.djangoapps.course_home_api.utils import get_course_or_403 from lms.djangoapps.course_home_api.tasks import collect_progress_for_user_in_course -from lms.djangoapps.course_home_api.toggles import send_course_progress_analytics_for_student_is_enabled +from lms.djangoapps.course_home_api.toggles import ( + learner_can_preview_verified_content, + send_course_progress_analytics_for_student_is_enabled, +) from lms.djangoapps.courseware.access import has_access from lms.djangoapps.courseware.context_processor import user_timezone_locale_prefs from lms.djangoapps.courseware.courses import get_course_date_blocks, get_course_info_section @@ -209,6 +212,7 @@ def get(self, request, *args, **kwargs): # pylint: disable=too-many-statements allow_anonymous = COURSE_ENABLE_UNENROLLED_ACCESS_FLAG.is_enabled(course_key) allow_public = allow_anonymous and course.course_visibility == COURSE_VISIBILITY_PUBLIC allow_public_outline = allow_anonymous and course.course_visibility == COURSE_VISIBILITY_PUBLIC_OUTLINE + allow_preview_of_verified_content = learner_can_preview_verified_content(course_key, request.user) # User locale settings user_timezone_locale = user_timezone_locale_prefs(request) @@ -309,7 +313,8 @@ def get(self, request, *args, **kwargs): # pylint: disable=too-many-statements # so this is a tiny first step in that migration. if course_blocks: user_course_outline = get_user_course_outline( - course_key, request.user, datetime.now(tz=timezone.utc) + course_key, request.user, datetime.now(tz=timezone.utc), + preview_verified_content=allow_preview_of_verified_content ) available_seq_ids = {str(usage_key) for usage_key in user_course_outline.sequences} @@ -339,6 +344,19 @@ def get(self, request, *args, **kwargs): # pylint: disable=too-many-statements ) ] if 'children' in chapter_data else [] + # For audit preview of verified content, we don't remove verified content. + # Instead, we mark it as preview so the frontend can handle it appropriately. + if allow_preview_of_verified_content: + previewable_sequences = {str(usage_key) for usage_key in user_course_outline.previewable_sequences} + + # Iterate through course_blocks to mark previewable sequences and chapters + for chapter_data in course_blocks['children']: + if chapter_data['id'] in previewable_sequences: + chapter_data['is_preview'] = True + for seq_data in chapter_data.get('children', []): + if seq_data['id'] in previewable_sequences: + seq_data['is_preview'] = True + user_has_passing_grade = False if not request.user.is_anonymous: user_grade = CourseGradeFactory().read(request.user, course) diff --git a/lms/djangoapps/course_home_api/progress/api.py b/lms/djangoapps/course_home_api/progress/api.py index b2a8634c59f4..f89ecd3d2596 100644 --- a/lms/djangoapps/course_home_api/progress/api.py +++ b/lms/djangoapps/course_home_api/progress/api.py @@ -2,14 +2,226 @@ Python APIs exposed for the progress tracking functionality of the course home API. """ +from __future__ import annotations + from django.contrib.auth import get_user_model from opaque_keys.edx.keys import CourseKey +from openedx.core.lib.grade_utils import round_away_from_zero +from xmodule.graders import ShowCorrectness +from datetime import datetime, timezone from lms.djangoapps.courseware.courses import get_course_blocks_completion_summary +from dataclasses import dataclass, field User = get_user_model() +@dataclass +class _AssignmentBucket: + """Holds scores and visibility info for one assignment type. + + Attributes: + assignment_type: Full assignment type name from the grading policy (for example, "Homework"). + num_total: The total number of assignments expected to contribute to the grade before any + drop-lowest rules are applied. + last_grade_publish_date: The most recent date when grades for all assignments of assignment_type + are released and included in the final grade. + scores: Per-subsection fractional scores (each value is ``earned / possible`` and falls in + the range 0–1). While awaiting published content we pad the list with zero placeholders + so that its length always matches ``num_total`` until real scores replace them. + visibilities: Mirrors ``scores`` index-for-index and records whether each subsection's + correctness feedback is visible to the learner (``True``), hidden (``False``), or not + yet populated (``None`` when the entry is a placeholder). + included: Tracks whether each subsection currently counts toward the learner's grade as + determined by ``SubsectionGrade.show_grades``. Values follow the same convention as + ``visibilities`` (``True`` / ``False`` / ``None`` placeholders). + assignments_created: Count of real subsections inserted into the bucket so far. Once this + reaches ``num_total``, all placeholder entries have been replaced with actual data. + """ + assignment_type: str + num_total: int + last_grade_publish_date: datetime + scores: list[float] = field(default_factory=list) + visibilities: list[bool | None] = field(default_factory=list) + included: list[bool | None] = field(default_factory=list) + assignments_created: int = 0 + + @classmethod + def with_placeholders(cls, assignment_type: str, num_total: int, now: datetime): + """Create a bucket prefilled with placeholder (empty) entries.""" + return cls( + assignment_type=assignment_type, + num_total=num_total, + last_grade_publish_date=now, + scores=[0] * num_total, + visibilities=[None] * num_total, + included=[None] * num_total, + ) + + def add_subsection(self, score: float, is_visible: bool, is_included: bool): + """Add a subsection’s score and visibility, replacing a placeholder if space remains.""" + if self.assignments_created < self.num_total: + if self.scores: + self.scores.pop(0) + if self.visibilities: + self.visibilities.pop(0) + if self.included: + self.included.pop(0) + self.scores.append(score) + self.visibilities.append(is_visible) + self.included.append(is_included) + self.assignments_created += 1 + + def drop_lowest(self, num_droppable: int): + """Remove the lowest scoring subsections, up to the provided num_droppable.""" + while num_droppable > 0 and self.scores: + idx = self.scores.index(min(self.scores)) + self.scores.pop(idx) + self.visibilities.pop(idx) + self.included.pop(idx) + num_droppable -= 1 + + def hidden_state(self) -> str: + """Return whether kept scores are all, some, or none hidden.""" + if not self.visibilities: + return 'none' + all_hidden = all(v is False for v in self.visibilities) + some_hidden = any(v is False for v in self.visibilities) + if all_hidden: + return 'all' + if some_hidden: + return 'some' + return 'none' + + def averages(self) -> tuple[float, float]: + """Compute visible and included averages over kept scores. + + Visible average uses only grades with visibility flag True in numerator; denominator is total + number of kept scores (mirrors legacy behavior). Included average uses only scores that are + marked included (show_grades True) in numerator with same denominator. + + Returns: + (earned_visible, earned_all) tuple of floats (0-1 each). + """ + if not self.scores: + return 0.0, 0.0 + visible_scores = [s for i, s in enumerate(self.scores) if self.visibilities[i]] + included_scores = [s for i, s in enumerate(self.scores) if self.included[i]] + earned_visible = (sum(visible_scores) / len(self.scores)) if self.scores else 0.0 + earned_all = (sum(included_scores) / len(self.scores)) if self.scores else 0.0 + return earned_visible, earned_all + + +class _AssignmentTypeGradeAggregator: + """Collects and aggregates subsection grades by assignment type.""" + + def __init__(self, course_grade, grading_policy: dict, has_staff_access: bool): + """Initialize with course grades, grading policy, and staff access flag.""" + self.course_grade = course_grade + self.grading_policy = grading_policy + self.has_staff_access = has_staff_access + self.now = datetime.now(timezone.utc) + self.policy_map = self._build_policy_map() + self.buckets: dict[str, _AssignmentBucket] = {} + + def _build_policy_map(self) -> dict: + """Convert grading policy into a lookup of assignment type → policy info.""" + policy_map = {} + for policy in self.grading_policy.get('GRADER', []): + policy_map[policy.get('type')] = { + 'weight': policy.get('weight', 0.0), + 'short_label': policy.get('short_label', ''), + 'num_droppable': policy.get('drop_count', 0), + 'num_total': policy.get('min_count', 0), + } + return policy_map + + def _bucket_for(self, assignment_type: str) -> _AssignmentBucket: + """Get or create a score bucket for the given assignment type.""" + bucket = self.buckets.get(assignment_type) + if bucket is None: + num_total = self.policy_map.get(assignment_type, {}).get('num_total', 0) or 0 + bucket = _AssignmentBucket.with_placeholders(assignment_type, num_total, self.now) + self.buckets[assignment_type] = bucket + return bucket + + def collect(self): + """Gather subsection grades into their respective assignment buckets.""" + for chapter in self.course_grade.chapter_grades.values(): + for subsection_grade in chapter.get('sections', []): + if not getattr(subsection_grade, 'graded', False): + continue + assignment_type = getattr(subsection_grade, 'format', '') or '' + if not assignment_type: + continue + graded_total = getattr(subsection_grade, 'graded_total', None) + earned = getattr(graded_total, 'earned', 0.0) if graded_total else 0.0 + possible = getattr(graded_total, 'possible', 0.0) if graded_total else 0.0 + earned = 0.0 if earned is None else earned + possible = 0.0 if possible is None else possible + score = (earned / possible) if possible else 0.0 + is_visible = ShowCorrectness.correctness_available( + subsection_grade.show_correctness, subsection_grade.due, self.has_staff_access + ) + is_included = subsection_grade.show_grades(self.has_staff_access) + bucket = self._bucket_for(assignment_type) + bucket.add_subsection(score, is_visible, is_included) + visibilities_with_due_dates = [ShowCorrectness.PAST_DUE, ShowCorrectness.NEVER_BUT_INCLUDE_GRADE] + if subsection_grade.show_correctness in visibilities_with_due_dates: + if subsection_grade.due and subsection_grade.due > bucket.last_grade_publish_date: + bucket.last_grade_publish_date = subsection_grade.due + + def build_results(self) -> dict: + """Apply drops, compute averages, and return aggregated results and total grade.""" + final_grades = 0.0 + rows = [] + for assignment_type, bucket in self.buckets.items(): + policy = self.policy_map.get(assignment_type, {}) + bucket.drop_lowest(policy.get('num_droppable', 0)) + earned_visible, earned_all = bucket.averages() + weight = policy.get('weight', 0.0) + short_label = policy.get('short_label', '') + row = { + 'type': assignment_type, + 'weight': weight, + 'average_grade': round_away_from_zero(earned_visible, 4), + 'weighted_grade': round_away_from_zero(earned_visible * weight, 4), + 'short_label': short_label, + 'num_droppable': policy.get('num_droppable', 0), + 'last_grade_publish_date': bucket.last_grade_publish_date, + 'has_hidden_contribution': bucket.hidden_state(), + } + final_grades += earned_all * weight + rows.append(row) + rows.sort(key=lambda r: r['weight']) + return {'results': rows, 'final_grades': round_away_from_zero(final_grades, 4)} + + def run(self) -> dict: + """Execute full pipeline (collect + aggregate) returning final payload.""" + self.collect() + return self.build_results() + + +def aggregate_assignment_type_grade_summary( + course_grade, + grading_policy: dict, + has_staff_access: bool = False, +) -> dict: + """ + Aggregate subsection grades by assignment type and return summary data. + Args: + course_grade: CourseGrade object containing chapter and subsection grades. + grading_policy: Dictionary representing the course's grading policy. + has_staff_access: Boolean indicating if the user has staff access to view all grades. + Returns: + Dictionary with keys: + results: list of per-assignment-type summary dicts + final_grades: overall weighted contribution (float, 4 decimal rounding) + """ + aggregator = _AssignmentTypeGradeAggregator(course_grade, grading_policy, has_staff_access) + return aggregator.run() + + def calculate_progress_for_learner_in_course(course_key: CourseKey, user: User) -> dict: """ Calculate a given learner's progress in the specified course run. diff --git a/lms/djangoapps/course_home_api/progress/serializers.py b/lms/djangoapps/course_home_api/progress/serializers.py index 6bdc204434af..c48660a41c6a 100644 --- a/lms/djangoapps/course_home_api/progress/serializers.py +++ b/lms/djangoapps/course_home_api/progress/serializers.py @@ -26,6 +26,7 @@ class SubsectionScoresSerializer(ReadOnlySerializer): assignment_type = serializers.CharField(source='format') block_key = serializers.SerializerMethodField() display_name = serializers.CharField() + due = serializers.DateTimeField(allow_null=True) has_graded_assignment = serializers.BooleanField(source='graded') override = serializers.SerializerMethodField() learner_has_access = serializers.SerializerMethodField() @@ -127,6 +128,20 @@ class VerificationDataSerializer(ReadOnlySerializer): status_date = serializers.DateTimeField() +class AssignmentTypeScoresSerializer(ReadOnlySerializer): + """ + Serializer for aggregated scores per assignment type. + """ + type = serializers.CharField() + weight = serializers.FloatField() + average_grade = serializers.FloatField() + weighted_grade = serializers.FloatField() + last_grade_publish_date = serializers.DateTimeField() + has_hidden_contribution = serializers.CharField() + short_label = serializers.CharField() + num_droppable = serializers.IntegerField() + + class ProgressTabSerializer(VerifiedModeSerializer): """ Serializer for progress tab @@ -146,3 +161,5 @@ class ProgressTabSerializer(VerifiedModeSerializer): user_has_passing_grade = serializers.BooleanField() verification_data = VerificationDataSerializer() disable_progress_graph = serializers.BooleanField() + assignment_type_grade_summary = AssignmentTypeScoresSerializer(many=True) + final_grades = serializers.FloatField() diff --git a/lms/djangoapps/course_home_api/progress/tests/test_api.py b/lms/djangoapps/course_home_api/progress/tests/test_api.py index 30d8d9059eaa..51e7dd68286e 100644 --- a/lms/djangoapps/course_home_api/progress/tests/test_api.py +++ b/lms/djangoapps/course_home_api/progress/tests/test_api.py @@ -6,7 +6,80 @@ from django.test import TestCase -from lms.djangoapps.course_home_api.progress.api import calculate_progress_for_learner_in_course +from lms.djangoapps.course_home_api.progress.api import ( + calculate_progress_for_learner_in_course, + aggregate_assignment_type_grade_summary, +) +from xmodule.graders import ShowCorrectness +from datetime import datetime, timedelta, timezone +from types import SimpleNamespace + + +def _make_subsection(fmt, earned, possible, show_corr, *, due_delta_days=None): + """Build a lightweight subsection object for testing aggregation scenarios.""" + graded_total = SimpleNamespace(earned=earned, possible=possible) + due = None + if due_delta_days is not None: + due = datetime.now(timezone.utc) + timedelta(days=due_delta_days) + return SimpleNamespace( + graded=True, + format=fmt, + graded_total=graded_total, + show_correctness=show_corr, + due=due, + show_grades=lambda staff: True, + ) + + +_AGGREGATION_SCENARIOS = [ + ( + 'all_visible_always', + {'type': 'Homework', 'weight': 1.0, 'drop_count': 0, 'min_count': 2, 'short_label': 'HW'}, + [ + _make_subsection('Homework', 1, 1, ShowCorrectness.ALWAYS), + _make_subsection('Homework', 0.5, 1, ShowCorrectness.ALWAYS), + ], + {'avg': 0.75, 'weighted': 0.75, 'hidden': 'none', 'final': 0.75}, + ), + ( + 'some_hidden_never_but_include', + {'type': 'Exam', 'weight': 1.0, 'drop_count': 0, 'min_count': 2, 'short_label': 'EX'}, + [ + _make_subsection('Exam', 1, 1, ShowCorrectness.ALWAYS), + _make_subsection('Exam', 0.5, 1, ShowCorrectness.NEVER_BUT_INCLUDE_GRADE), + ], + {'avg': 0.5, 'weighted': 0.5, 'hidden': 'some', 'final': 0.75}, + ), + ( + 'all_hidden_never_but_include', + {'type': 'Quiz', 'weight': 1.0, 'drop_count': 0, 'min_count': 2, 'short_label': 'QZ'}, + [ + _make_subsection('Quiz', 0.4, 1, ShowCorrectness.NEVER_BUT_INCLUDE_GRADE), + _make_subsection('Quiz', 0.6, 1, ShowCorrectness.NEVER_BUT_INCLUDE_GRADE), + ], + {'avg': 0.0, 'weighted': 0.0, 'hidden': 'all', 'final': 0.5}, + ), + ( + 'past_due_mixed_visibility', + {'type': 'Lab', 'weight': 1.0, 'drop_count': 0, 'min_count': 2, 'short_label': 'LB'}, + [ + _make_subsection('Lab', 0.8, 1, ShowCorrectness.PAST_DUE, due_delta_days=-1), + _make_subsection('Lab', 0.2, 1, ShowCorrectness.PAST_DUE, due_delta_days=+3), + ], + {'avg': 0.4, 'weighted': 0.4, 'hidden': 'some', 'final': 0.5}, + ), + ( + 'drop_lowest_keeps_high_scores', + {'type': 'Project', 'weight': 1.0, 'drop_count': 2, 'min_count': 4, 'short_label': 'PR'}, + [ + _make_subsection('Project', 1, 1, ShowCorrectness.ALWAYS), + _make_subsection('Project', 1, 1, ShowCorrectness.ALWAYS), + _make_subsection('Project', 0, 1, ShowCorrectness.ALWAYS), + _make_subsection('Project', 0, 1, ShowCorrectness.ALWAYS), + ], + {'avg': 1.0, 'weighted': 1.0, 'hidden': 'none', 'final': 1.0}, + ), +] class ProgressApiTests(TestCase): @@ -73,3 +146,37 @@ def test_calculate_progress_for_learner_in_course_summary_empty(self, mock_get_s results = calculate_progress_for_learner_in_course("some_course", "some_user") assert not results + + def test_aggregate_assignment_type_grade_summary_scenarios(self): + """ + A test to verify functionality of aggregate_assignment_type_grade_summary. + 1. Test visibility modes (always, never but include grade, past due) + 2. Test drop-lowest behavior + 3. Test weighting behavior + 4. Test final grade calculation + 5. Test average grade calculation + 6. Test weighted grade calculation + 7. Test has_hidden_contribution calculation + """ + + for case_name, policy, subsections, expected in _AGGREGATION_SCENARIOS: + with self.subTest(case_name=case_name): + course_grade = SimpleNamespace(chapter_grades={'chapter': {'sections': subsections}}) + grading_policy = {'GRADER': [policy]} + + result = aggregate_assignment_type_grade_summary( + course_grade, + grading_policy, + has_staff_access=False, + ) + + assert 'results' in result and 'final_grades' in result + assert result['final_grades'] == expected['final'] + assert len(result['results']) == 1 + + row = result['results'][0] + assert row['type'] == policy['type'], case_name + assert row['average_grade'] == expected['avg'] + assert row['weighted_grade'] == expected['weighted'] + assert row['has_hidden_contribution'] == expected['hidden'] + assert row['num_droppable'] == policy['drop_count'] diff --git a/lms/djangoapps/course_home_api/progress/tests/test_views.py b/lms/djangoapps/course_home_api/progress/tests/test_views.py index d13ebec29c21..8012e11675f1 100644 --- a/lms/djangoapps/course_home_api/progress/tests/test_views.py +++ b/lms/djangoapps/course_home_api/progress/tests/test_views.py @@ -282,8 +282,8 @@ def test_url_hidden_if_subsection_hide_after_due(self): assert hide_after_due_subsection['url'] is None @ddt.data( - (True, 0.7), # midterm and final are visible to staff - (False, 0.3), # just the midterm is visible to learners + (True, 0.72), # lab, midterm and final are visible to staff + (False, 0.32), # Only lab and midterm is visible to learners ) @ddt.unpack def test_course_grade_considers_subsection_grade_visibility(self, is_staff, expected_percent): @@ -301,14 +301,18 @@ def test_course_grade_considers_subsection_grade_visibility(self, is_staff, expe never = self.add_subsection_with_problem(format='Homework', show_correctness='never') always = self.add_subsection_with_problem(format='Midterm Exam', show_correctness='always') past_due = self.add_subsection_with_problem(format='Final Exam', show_correctness='past_due', due=tomorrow) + never_but_show_grade = self.add_subsection_with_problem( + format='Lab', show_correctness='never_but_include_grade' + ) answer_problem(self.course, get_mock_request(self.user), never) answer_problem(self.course, get_mock_request(self.user), always) answer_problem(self.course, get_mock_request(self.user), past_due) + answer_problem(self.course, get_mock_request(self.user), never_but_show_grade) # First, confirm the grade in the database - it should never change based on user state. # This is midterm and final and a single problem added together. - assert CourseGradeFactory().read(self.user, self.course).percent == 0.72 + assert CourseGradeFactory().read(self.user, self.course).percent == 0.73 response = self.client.get(self.url) assert response.status_code == 200 diff --git a/lms/djangoapps/course_home_api/progress/views.py b/lms/djangoapps/course_home_api/progress/views.py index 3783c19061dc..54e71df48cc5 100644 --- a/lms/djangoapps/course_home_api/progress/views.py +++ b/lms/djangoapps/course_home_api/progress/views.py @@ -13,8 +13,11 @@ from rest_framework.response import Response from xmodule.modulestore.django import modulestore +from xmodule.graders import ShowCorrectness from common.djangoapps.student.models import CourseEnrollment from lms.djangoapps.course_home_api.progress.serializers import ProgressTabSerializer +from lms.djangoapps.course_home_api.progress.api import aggregate_assignment_type_grade_summary + from lms.djangoapps.course_home_api.toggles import course_home_mfe_progress_tab_is_active from lms.djangoapps.courseware.access import has_access, has_ccx_coach_role from lms.djangoapps.course_blocks.api import get_course_blocks @@ -99,6 +102,7 @@ class ProgressTabView(RetrieveAPIView): assignment_type: (str) the format, if any, of the Subsection (Homework, Exam, etc) block_key: (str) the key of the given subsection block display_name: (str) a str of what the name of the Subsection is for displaying on the site + due: (str or None) the due date of the subsection in ISO 8601 format, or None if no due date is set has_graded_assignment: (bool) whether or not the Subsection is a graded assignment learner_has_access: (bool) whether the learner has access to the subsection (could be FBE gated) num_points_earned: (int) the amount of points the user has earned for the given subsection @@ -175,6 +179,18 @@ def _get_student_user(self, request, course_key, student_id, is_staff): except User.DoesNotExist as exc: raise Http404 from exc + def _visible_section_scores(self, course_grade): + """Return only those chapter/section scores that are visible to the learner.""" + visible_chapters = [] + for chapter in course_grade.chapter_grades.values(): + filtered_sections = [ + subsection + for subsection in chapter["sections"] + if getattr(subsection, "show_correctness", None) != ShowCorrectness.NEVER_BUT_INCLUDE_GRADE + ] + visible_chapters.append({**chapter, "sections": filtered_sections}) + return visible_chapters + def get(self, request, *args, **kwargs): course_key_string = kwargs.get('course_key_string') course_key = CourseKey.from_string(course_key_string) @@ -245,6 +261,16 @@ def get(self, request, *args, **kwargs): access_expiration = get_access_expiration_data(request.user, course_overview) + # Aggregations delegated to helper functions for reuse and testability + assignment_type_grade_summary = aggregate_assignment_type_grade_summary( + course_grade, + grading_policy, + has_staff_access=is_staff, + ) + + # Filter out section scores to only have those that are visible to the user + section_scores = self._visible_section_scores(course_grade) + data = { 'access_expiration': access_expiration, 'certificate_data': get_cert_data(student, course, enrollment_mode, course_grade), @@ -255,12 +281,14 @@ def get(self, request, *args, **kwargs): 'enrollment_mode': enrollment_mode, 'grading_policy': grading_policy, 'has_scheduled_content': has_scheduled_content, - 'section_scores': list(course_grade.chapter_grades.values()), + 'section_scores': section_scores, 'studio_url': get_studio_url(course, 'settings/grading'), 'username': username, 'user_has_passing_grade': user_has_passing_grade, 'verification_data': verification_data, 'disable_progress_graph': disable_progress_graph, + 'assignment_type_grade_summary': assignment_type_grade_summary["results"], + 'final_grades': assignment_type_grade_summary["final_grades"], } context = self.get_serializer_context() context['staff_access'] = is_staff diff --git a/lms/djangoapps/course_home_api/tests/test_toggles.py b/lms/djangoapps/course_home_api/tests/test_toggles.py new file mode 100644 index 000000000000..46ab545d0ade --- /dev/null +++ b/lms/djangoapps/course_home_api/tests/test_toggles.py @@ -0,0 +1,155 @@ +""" +Tests for Course Home API toggles. +""" + +from unittest.mock import Mock, patch + +from django.test import TestCase +from opaque_keys.edx.keys import CourseKey + +from common.djangoapps.course_modes.models import CourseMode +from common.djangoapps.course_modes.tests.factories import CourseModeFactory + +from ..toggles import learner_can_preview_verified_content + + +class TestLearnerCanPreviewVerifiedContent(TestCase): + """Test cases for learner_can_preview_verified_content function.""" + + def setUp(self): + """Set up test fixtures.""" + self.course_key = CourseKey.from_string("course-v1:TestX+CS101+2024") + self.user = Mock() + + # Set up patchers + self.feature_enabled_patcher = patch( + "lms.djangoapps.course_home_api.toggles.audit_learner_verified_preview_is_enabled" + ) + self.verified_mode_for_course_patcher = patch( + "common.djangoapps.course_modes.models.CourseMode.verified_mode_for_course" + ) + self.get_enrollment_patcher = patch( + "common.djangoapps.student.models.CourseEnrollment.get_enrollment" + ) + + # Course set up with verified, professional, and audit modes + self.verified_mode = CourseModeFactory( + course_id=self.course_key, + mode_slug=CourseMode.VERIFIED, + mode_display_name="Verified", + ) + self.professional_mode = CourseModeFactory( + course_id=self.course_key, + mode_slug=CourseMode.PROFESSIONAL, + mode_display_name="Professional", + ) + self.audit_mode = CourseModeFactory( + course_id=self.course_key, + mode_slug=CourseMode.AUDIT, + mode_display_name="Audit", + ) + self.course_modes_dict = { + "audit": self.audit_mode, + "verified": self.verified_mode, + "professional": self.professional_mode, + } + + # Start patchers + self.mock_feature_enabled = self.feature_enabled_patcher.start() + self.mock_verified_mode_for_course = ( + self.verified_mode_for_course_patcher.start() + ) + self.mock_get_enrollment = self.get_enrollment_patcher.start() + + def _enroll_user(self, mode): + """Helper method to set up user enrollment mock.""" + mock_enrollment = Mock() + mock_enrollment.mode = mode + self.mock_get_enrollment.return_value = mock_enrollment + + def tearDown(self): + """Clean up patchers.""" + self.feature_enabled_patcher.stop() + self.verified_mode_for_course_patcher.stop() + self.get_enrollment_patcher.stop() + + def test_all_conditions_met_returns_true(self): + """Test that function returns True when all conditions are met.""" + # Given the feature is enabled, course has verified mode, and user is enrolled as audit + self.mock_feature_enabled.return_value = True + self.mock_verified_mode_for_course.return_value = self.course_modes_dict[ + "professional" + ] + self._enroll_user(CourseMode.AUDIT) + + # When I check if the learner can preview verified content + result = learner_can_preview_verified_content(self.course_key, self.user) + + # Then the result should be True + self.assertTrue(result) + + def test_feature_disabled_returns_false(self): + """Test that function returns False when feature is disabled.""" + # Given the feature is disabled + self.mock_feature_enabled.return_value = False + + # ... even if all other conditions are met + self.mock_verified_mode_for_course.return_value = self.course_modes_dict[ + "professional" + ] + self._enroll_user(CourseMode.AUDIT) + + # When I check if the learner can preview verified content + result = learner_can_preview_verified_content(self.course_key, self.user) + + # Then the result should be False + self.assertFalse(result) + + def test_no_verified_mode_returns_false(self): + """Test that function returns False when course has no verified mode.""" + # Given the course does not have a verified mode + self.mock_verified_mode_for_course.return_value = None + + # ... even if all other conditions are met + self.mock_feature_enabled.return_value = True + self._enroll_user(CourseMode.AUDIT) + + # When I check if the learner can preview verified content + result = learner_can_preview_verified_content(self.course_key, self.user) + + # Then the result should be False + self.assertFalse(result) + + def test_no_enrollment_returns_false(self): + """Test that function returns False when user is not enrolled.""" + # Given the user is unenrolled + self.mock_get_enrollment.return_value = None + + # ... even if all other conditions are met + self.mock_feature_enabled.return_value = True + self.mock_verified_mode_for_course.return_value = self.course_modes_dict[ + "professional" + ] + + # When I check if the learner can preview verified content + result = learner_can_preview_verified_content(self.course_key, self.user) + + # Then the result should be False + self.assertFalse(result) + + def test_verified_enrollment_returns_false(self): + """Test that function returns False when user is enrolled in verified mode.""" + # Given the user is not enrolled as audit + self._enroll_user(CourseMode.VERIFIED) + + # ... even if all other conditions are met + self.mock_feature_enabled.return_value = True + self.mock_verified_mode_for_course.return_value = self.course_modes_dict[ + "professional" + ] + + # When I check if the learner can preview verified content + result = learner_can_preview_verified_content(self.course_key, self.user) + + # Then the result should be False + self.assertFalse(result) diff --git a/lms/djangoapps/course_home_api/toggles.py b/lms/djangoapps/course_home_api/toggles.py index 052862796c75..1f2d32b87e96 100644 --- a/lms/djangoapps/course_home_api/toggles.py +++ b/lms/djangoapps/course_home_api/toggles.py @@ -3,6 +3,9 @@ """ from openedx.core.djangoapps.waffle_utils import CourseWaffleFlag +from openedx.core.lib.cache_utils import request_cached +from common.djangoapps.course_modes.models import CourseMode +from common.djangoapps.student.models import CourseEnrollment WAFFLE_FLAG_NAMESPACE = 'course_home' @@ -51,6 +54,21 @@ ) +# Waffle flag to enable audit learner preview of course structure visible to verified learners. +# +# .. toggle_name: course_home.audit_learner_verified_preview +# .. toggle_implementation: CourseWaffleFlag +# .. toggle_default: False +# .. toggle_description: Where enabled, audit learners can see the presence of the sections / units +# otherwise restricted to verified learners. The content itself remains inaccessible. +# .. toggle_use_cases: open_edx +# .. toggle_creation_date: 2025-11-07 +# .. toggle_target_removal_date: None +COURSE_HOME_AUDIT_LEARNER_VERIFIED_PREVIEW = CourseWaffleFlag( + f'{WAFFLE_FLAG_NAMESPACE}.audit_learner_verified_preview', __name__ +) + + def course_home_mfe_progress_tab_is_active(course_key): # Avoiding a circular dependency from .models import DisableProgressPageStackedConfig @@ -73,3 +91,40 @@ def send_course_progress_analytics_for_student_is_enabled(course_key): Returns True if the course completion analytics feature is enabled for a given course. """ return COURSE_HOME_SEND_COURSE_PROGRESS_ANALYTICS_FOR_STUDENT.is_enabled(course_key) + + +def audit_learner_verified_preview_is_enabled(course_key): + """ + Returns True if the audit learner verified preview feature is enabled for a given course. + """ + return COURSE_HOME_AUDIT_LEARNER_VERIFIED_PREVIEW.is_enabled(course_key) + + +@request_cached() +def learner_can_preview_verified_content(course_key, user): + """ + Determine if an audit learner can preview verified content in a course. + + Args: + course_key: The course identifier. + user: The user object + Returns: + True if the learner can preview verified content, False otherwise. + """ + # To preview verified content, the feature must be enabled for the course... + feature_enabled = audit_learner_verified_preview_is_enabled(course_key) + if not feature_enabled: + return False + + # ... the course must have a verified mode + course_has_verified_mode = CourseMode.verified_mode_for_course(course_key) + if not course_has_verified_mode: + return False + + # ... and the user must be enrolled as audit + enrollment = CourseEnrollment.get_enrollment(user, course_key) + user_enrolled_as_audit = enrollment is not None and enrollment.mode == CourseMode.AUDIT + if not user_enrolled_as_audit: + return False + + return True diff --git a/lms/djangoapps/courseware/tests/test_views.py b/lms/djangoapps/courseware/tests/test_views.py index 4e3d1be9bddc..4a3f16523ce5 100644 --- a/lms/djangoapps/courseware/tests/test_views.py +++ b/lms/djangoapps/courseware/tests/test_views.py @@ -79,6 +79,7 @@ COURSEWARE_MICROFRONTEND_ENABLE_NAVIGATION_SIDEBAR, COURSEWARE_MICROFRONTEND_SEARCH_ENABLED, COURSEWARE_OPTIMIZED_RENDER_XBLOCK, + ENABLE_UNIFIED_SITE_AND_TRANSLATION_LANGUAGE, ) from completion.waffle import ENABLE_COMPLETION_TRACKING_SWITCH from lms.djangoapps.courseware.user_state_client import DjangoXBlockUserStateClient @@ -1781,6 +1782,14 @@ def assert_progress_page_show_grades(self, response, show_correctness, due_date, (ShowCorrectness.PAST_DUE, TODAY, True), (ShowCorrectness.PAST_DUE, TOMORROW, False), (ShowCorrectness.PAST_DUE, TOMORROW, True), + (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, None, False), + (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, None, True), + (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, YESTERDAY, False), + (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, YESTERDAY, True), + (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, TODAY, False), + (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, TODAY, True), + (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, TOMORROW, False), + (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, TOMORROW, True), ) @ddt.unpack def test_progress_page_no_problem_scores(self, show_correctness, due_date_name, graded): @@ -1821,6 +1830,14 @@ def test_progress_page_no_problem_scores(self, show_correctness, due_date_name, (ShowCorrectness.PAST_DUE, TODAY, True, True), (ShowCorrectness.PAST_DUE, TOMORROW, False, False), (ShowCorrectness.PAST_DUE, TOMORROW, True, False), + (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, None, False, False), + (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, None, True, False), + (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, YESTERDAY, False, False), + (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, YESTERDAY, True, False), + (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, TODAY, False, False), + (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, TODAY, True, False), + (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, TOMORROW, False, False), + (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, TOMORROW, True, False), ) @ddt.unpack def test_progress_page_hide_scores_from_learner(self, show_correctness, due_date_name, graded, show_grades): @@ -1873,11 +1890,20 @@ def test_progress_page_hide_scores_from_learner(self, show_correctness, due_date (ShowCorrectness.PAST_DUE, TODAY, True, True), (ShowCorrectness.PAST_DUE, TOMORROW, False, True), (ShowCorrectness.PAST_DUE, TOMORROW, True, True), + (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, None, False, False), + (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, None, True, False), + (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, YESTERDAY, False, False), + (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, YESTERDAY, True, False), + (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, TODAY, False, False), + (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, TODAY, True, False), + (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, TOMORROW, False, False), + (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, TOMORROW, True, False), ) @ddt.unpack def test_progress_page_hide_scores_from_staff(self, show_correctness, due_date_name, graded, show_grades): """ - Test that problem scores are hidden from staff viewing a learner's progress page only if show_correctness=never. + Test that problem scores are hidden from staff viewing a learner's progress page only if show_correctness is + never or never_but_include_grade. """ due_date = self.DATES[due_date_name] self.setup_course(show_correctness=show_correctness, due_date=due_date, graded=graded) @@ -3421,3 +3447,25 @@ def test_course_about_redirect_to_mfe(self, catalog_mfe_enabled, expected_redire assert response.url == "http://example.com/catalog/courses/{}/about".format(self.course.id) else: assert response.status_code == 200 + + +@ddt.ddt +class UnifiedSiteAndTranslationLanguageEnabledViewTests(TestCase): + """ + Tests for the unified_site_and_translation_language_enabled view + """ + @override_waffle_flag(ENABLE_UNIFIED_SITE_AND_TRANSLATION_LANGUAGE, True) + def test_view_logged_out(self): + url = reverse('unified_translations_enabled_view') + self.client.logout() + response = self.client.get(url) + assert response.status_code == 302 + + @ddt.data(True, False) + def test_view(self, enabled): + url = reverse('unified_translations_enabled_view') + user = UserFactory.create() + assert self.client.login(username=user.username, password=TEST_PASSWORD) + with override_waffle_flag(ENABLE_UNIFIED_SITE_AND_TRANSLATION_LANGUAGE, enabled): + response = self.client.get(url) + assert response.json()['enabled'] == enabled diff --git a/lms/djangoapps/courseware/toggles.py b/lms/djangoapps/courseware/toggles.py index f9f083cad42e..4d3ba067162d 100644 --- a/lms/djangoapps/courseware/toggles.py +++ b/lms/djangoapps/courseware/toggles.py @@ -2,7 +2,7 @@ Toggles for courseware in-course experience. """ -from edx_toggles.toggles import SettingToggle, WaffleSwitch +from edx_toggles.toggles import SettingToggle, WaffleSwitch, WaffleFlag from openedx.core.djangoapps.waffle_utils import CourseWaffleFlag @@ -168,6 +168,19 @@ f'{WAFFLE_FLAG_NAMESPACE}.discovery_default_language_filter', __name__ ) +# .. toggle_name: courseware.unify_site_and_translation_language +# .. toggle_implementation: WaffleFlag +# .. toggle_default: False +# .. toggle_description: Update LMS to use site language for xpert unit translations and enable new header site language switcher. +# .. toggle_use_cases: opt_in +# .. toggle_creation_date: 2026-01-08 +# .. toggle_target_removal_date: None +# .. toggle_warning: None. +# .. toggle_tickets: https://github.com/edx/edx-platform/pull/81 +ENABLE_UNIFIED_SITE_AND_TRANSLATION_LANGUAGE = WaffleFlag( + f'{WAFFLE_FLAG_NAMESPACE}.unify_site_and_translation_language', __name__ +) + def course_exit_page_is_active(course_key): return COURSEWARE_MICROFRONTEND_COURSE_EXIT_PAGE.is_enabled(course_key) @@ -202,3 +215,10 @@ def courseware_disable_navigation_sidebar_blocks_caching(course_key=None): Return whether the courseware.disable_navigation_sidebar_blocks_caching flag is on. """ return COURSEWARE_MICROFRONTEND_NAVIGATION_SIDEBAR_BLOCKS_DISABLE_CACHING.is_enabled(course_key) + + +def unified_site_and_translation_language_is_enabled(): + """ + Return whether the courseware.unify_site_and_translation_language flag is on. + """ + return ENABLE_UNIFIED_SITE_AND_TRANSLATION_LANGUAGE.is_enabled() diff --git a/lms/djangoapps/courseware/views/views.py b/lms/djangoapps/courseware/views/views.py index 2c89800d82b6..ccee9a0fa729 100644 --- a/lms/djangoapps/courseware/views/views.py +++ b/lms/djangoapps/courseware/views/views.py @@ -159,6 +159,7 @@ from ..toggles import ( COURSEWARE_OPTIMIZED_RENDER_XBLOCK, ENABLE_COURSE_DISCOVERY_DEFAULT_LANGUAGE_FILTER, + unified_site_and_translation_language_is_enabled, ) log = logging.getLogger("edx.courseware") @@ -2402,3 +2403,12 @@ def courseware_mfe_navigation_sidebar_toggles(request, course_id=None): # Add completion tracking status for the sidebar use while a global place for switches is put in place "enable_completion_tracking": ENABLE_COMPLETION_TRACKING_SWITCH.is_enabled() }) + + +@login_required +@api_view(['GET']) +def unified_site_and_translation_language_enabled(request): + """ + Simple GET endpoint to expose whether the user/course has access to the unified translations feature + """ + return JsonResponse({'enabled': unified_site_and_translation_language_is_enabled()}) diff --git a/lms/djangoapps/discussion/rest_api/api.py b/lms/djangoapps/discussion/rest_api/api.py index b87852c16cfa..ac131fa4c7e0 100644 --- a/lms/djangoapps/discussion/rest_api/api.py +++ b/lms/djangoapps/discussion/rest_api/api.py @@ -1,17 +1,17 @@ """ Discussion API internal interface """ + from __future__ import annotations import itertools +import logging import re from collections import defaultdict from datetime import datetime - from enum import Enum from typing import Dict, Iterable, List, Literal, Optional, Set, Tuple from urllib.parse import urlencode, urlunparse -from pytz import UTC from django.conf import settings from django.contrib.auth import get_user_model @@ -19,24 +19,27 @@ from django.db.models import Q from django.http import Http404 from django.urls import reverse +from django.utils.html import strip_tags from edx_django_utils.monitoring import function_trace from opaque_keys import InvalidKeyError from opaque_keys.edx.locator import CourseKey +from pytz import UTC from rest_framework import status from rest_framework.exceptions import PermissionDenied from rest_framework.request import Request from rest_framework.response import Response -from common.djangoapps.student.roles import ( - CourseInstructorRole, - CourseStaffRole, -) - +from common.djangoapps.student.roles import CourseInstructorRole, CourseStaffRole +from forum import api as forum_api from lms.djangoapps.course_api.blocks.api import get_blocks from lms.djangoapps.courseware.courses import get_course_with_access from lms.djangoapps.courseware.exceptions import CourseAccessRedirect from lms.djangoapps.discussion.rate_limit import is_content_creation_rate_limited -from lms.djangoapps.discussion.toggles import ENABLE_DISCUSSIONS_MFE, ONLY_VERIFIED_USERS_CAN_POST +from lms.djangoapps.discussion.toggles import ( + ENABLE_DISCUSSIONS_MFE, + ENABLE_DISCUSSION_BAN, + ONLY_VERIFIED_USERS_CAN_POST, +) from lms.djangoapps.discussion.views import is_privileged_user from openedx.core.djangoapps.discussions.models import ( DiscussionsConfiguration, @@ -48,12 +51,12 @@ from openedx.core.djangoapps.django_comment_common.comment_client.comment import Comment from openedx.core.djangoapps.django_comment_common.comment_client.course import ( get_course_commentable_counts, - get_course_user_stats + get_course_user_stats, ) from openedx.core.djangoapps.django_comment_common.comment_client.thread import Thread from openedx.core.djangoapps.django_comment_common.comment_client.utils import ( CommentClient500Error, - CommentClientRequestError + CommentClientRequestError, ) from openedx.core.djangoapps.django_comment_common.models import ( FORUM_ROLE_ADMINISTRATOR, @@ -61,13 +64,13 @@ FORUM_ROLE_GROUP_MODERATOR, FORUM_ROLE_MODERATOR, CourseDiscussionSettings, - Role + Role, ) from openedx.core.djangoapps.django_comment_common.signals import ( comment_created, comment_deleted, - comment_endorsed, comment_edited, + comment_endorsed, comment_flagged, comment_voted, thread_created, @@ -75,11 +78,15 @@ thread_edited, thread_flagged, thread_followed, + thread_unfollowed, thread_voted, - thread_unfollowed ) from openedx.core.djangoapps.user_api.accounts.api import get_account_settings -from openedx.core.lib.exceptions import CourseNotFoundError, DiscussionNotFoundError, PageNotFoundError +from openedx.core.lib.exceptions import ( + CourseNotFoundError, + DiscussionNotFoundError, + PageNotFoundError, +) from xmodule.course_block import CourseBlock from xmodule.modulestore import ModuleStoreEnum from xmodule.modulestore.django import modulestore @@ -88,21 +95,27 @@ from ..django_comment_client.base.views import ( track_comment_created_event, track_comment_deleted_event, + track_discussion_reported_event, + track_discussion_unreported_event, + track_forum_search_event, track_thread_created_event, track_thread_deleted_event, + track_thread_followed_event, track_thread_viewed_event, track_voted_event, - track_discussion_reported_event, - track_discussion_unreported_event, - track_forum_search_event, track_thread_followed_event ) from ..django_comment_client.utils import ( get_group_id_for_user, get_user_role_names, has_discussion_privileges, - is_commentable_divided + is_commentable_divided, +) +from .exceptions import ( + CommentNotFoundError, + DiscussionBlackOutException, + DiscussionDisabledError, + ThreadNotFoundError, ) -from .exceptions import CommentNotFoundError, DiscussionBlackOutException, DiscussionDisabledError, ThreadNotFoundError from .forms import CommentActionsForm, ThreadActionsForm, UserOrdering from .pagination import DiscussionAPIPagination from .permissions import ( @@ -110,7 +123,7 @@ can_take_action_on_spam, get_editable_fields, get_initializable_comment_fields, - get_initializable_thread_fields + get_initializable_thread_fields, ) from .serializers import ( CommentSerializer, @@ -119,20 +132,23 @@ ThreadSerializer, TopicOrdering, UserStatsSerializer, - get_context + get_context, ) from .utils import ( AttributeDict, add_stats_for_users_with_no_discussion_content, + can_user_notify_all_learners, create_blocks_params, discussion_open_for_user, + get_captcha_site_key_by_platform, get_usernames_for_course, get_usernames_from_search_string, - set_attribute, + is_captcha_enabled, is_posting_allowed, - can_user_notify_all_learners, is_captcha_enabled, get_captcha_site_key_by_platform + set_attribute, ) +log = logging.getLogger(__name__) User = get_user_model() ThreadType = Literal["discussion", "question"] @@ -166,11 +182,14 @@ class DiscussionEntity(Enum): """ Enum for different types of discussion related entities """ - thread = 'thread' - comment = 'comment' + + thread = "thread" + comment = "comment" -def _get_course(course_key: CourseKey, user: User, check_tab: bool = True) -> CourseBlock: +def _get_course( + course_key: CourseKey, user: User, check_tab: bool = True +) -> CourseBlock: """ Get the course block, raising CourseNotFoundError if the course is not found or the user cannot access forums for the course, and DiscussionDisabledError if the @@ -188,14 +207,16 @@ def _get_course(course_key: CourseKey, user: User, check_tab: bool = True) -> Co CourseBlock: course object """ try: - course = get_course_with_access(user, 'load', course_key, check_if_enrolled=True) + course = get_course_with_access( + user, "load", course_key, check_if_enrolled=True + ) except (Http404, CourseAccessRedirect) as err: # Convert 404s into CourseNotFoundErrors. # Raise course not found if the user cannot access the course raise CourseNotFoundError("Course not found.") from err if check_tab: - discussion_tab = CourseTabList.get_tab_by_type(course.tabs, 'discussion') + discussion_tab = CourseTabList.get_tab_by_type(course.tabs, "discussion") if not (discussion_tab and discussion_tab.is_enabled(course, user)): raise DiscussionDisabledError("Discussion is disabled for the course.") @@ -216,22 +237,34 @@ def _get_thread_and_context(request, thread_id, retrieve_kwargs=None, course_id= retrieve_kwargs["with_responses"] = False if "mark_as_read" not in retrieve_kwargs: retrieve_kwargs["mark_as_read"] = False - cc_thread = Thread(id=thread_id).retrieve(course_id=course_id, **retrieve_kwargs) + cc_thread = Thread(id=thread_id).retrieve( + course_id=course_id, **retrieve_kwargs + ) course_key = CourseKey.from_string(cc_thread["course_id"]) course = _get_course(course_key, request.user) context = get_context(course, request, cc_thread) - if retrieve_kwargs.get("flagged_comments") and not context["has_moderation_privilege"]: + if ( + retrieve_kwargs.get("flagged_comments") + and not context["has_moderation_privilege"] + ): raise ValidationError("Only privileged users can request flagged comments") course_discussion_settings = CourseDiscussionSettings.get(course_key) if ( - not context["has_moderation_privilege"] and - cc_thread["group_id"] and - is_commentable_divided(course.id, cc_thread["commentable_id"], course_discussion_settings) + not context["has_moderation_privilege"] + and cc_thread["group_id"] + and is_commentable_divided( + course.id, cc_thread["commentable_id"], course_discussion_settings + ) ): - requester_group_id = get_group_id_for_user(request.user, course_discussion_settings) - if requester_group_id is not None and cc_thread["group_id"] != requester_group_id: + requester_group_id = get_group_id_for_user( + request.user, course_discussion_settings + ) + if ( + requester_group_id is not None + and cc_thread["group_id"] != requester_group_id + ): raise ThreadNotFoundError("Thread not found.") return cc_thread, context except CommentClientRequestError as err: @@ -264,8 +297,8 @@ def _is_user_author_or_privileged(cc_content, context): Boolean """ return ( - context["has_moderation_privilege"] or - context["cc_requester"]["id"] == cc_content["user_id"] + context["has_moderation_privilege"] + or context["cc_requester"]["id"] == cc_content["user_id"] ) @@ -275,11 +308,13 @@ def get_thread_list_url(request, course_key, topic_id_list=None, following=False """ path = reverse("thread-list") query_list = ( - [("course_id", str(course_key))] + - [("topic_id", topic_id) for topic_id in topic_id_list or []] + - ([("following", following)] if following else []) + [("course_id", str(course_key))] + + [("topic_id", topic_id) for topic_id in topic_id_list or []] + + ([("following", following)] if following else []) + ) + return request.build_absolute_uri( + urlunparse(("", "", path, "", urlencode(query_list), "")) ) - return request.build_absolute_uri(urlunparse(("", "", path, "", urlencode(query_list), ""))) def get_course(request, course_key, check_tab=True): @@ -324,23 +359,37 @@ def _format_datetime(dt): the substitution... though really, that would probably break mobile client parsing of the dates as well. :-P """ - return dt.isoformat().replace('+00:00', 'Z') + return dt.isoformat().replace("+00:00", "Z") course = _get_course(course_key, request.user, check_tab=check_tab) user_roles = get_user_role_names(request.user, course_key) course_config = DiscussionsConfiguration.get(course_key) EDIT_REASON_CODES = getattr(settings, "DISCUSSION_MODERATION_EDIT_REASON_CODES", {}) - CLOSE_REASON_CODES = getattr(settings, "DISCUSSION_MODERATION_CLOSE_REASON_CODES", {}) + CLOSE_REASON_CODES = getattr( + settings, "DISCUSSION_MODERATION_CLOSE_REASON_CODES", {} + ) is_posting_enabled = is_posting_allowed( - course_config.posting_restrictions, - course.get_discussion_blackout_datetimes() + course_config.posting_restrictions, course.get_discussion_blackout_datetimes() ) - discussion_tab = CourseTabList.get_tab_by_type(course.tabs, 'discussion') + discussion_tab = CourseTabList.get_tab_by_type(course.tabs, "discussion") is_course_staff = CourseStaffRole(course_key).has_user(request.user) is_course_admin = CourseInstructorRole(course_key).has_user(request.user) + + # Check if the user is banned from discussions + is_user_banned_func = getattr(forum_api, 'is_user_banned', None) + is_user_banned = False + # Only check ban status if feature flag is enabled + if ENABLE_DISCUSSION_BAN.is_enabled(course_key) and is_user_banned_func is not None: + try: + is_user_banned = is_user_banned_func(request.user, course_key) + except Exception: # pylint: disable=broad-except + # If ban check fails, default to False + is_user_banned = False + return { "id": str(course_key), "is_posting_enabled": is_posting_enabled, + "is_user_banned": is_user_banned, "blackouts": [ { "start": _format_datetime(blackout["start"]), @@ -349,7 +398,9 @@ def _format_datetime(dt): for blackout in course.get_discussion_blackout_datetimes() ], "thread_list_url": get_thread_list_url(request, course_key), - "following_thread_list_url": get_thread_list_url(request, course_key, following=True), + "following_thread_list_url": get_thread_list_url( + request, course_key, following=True + ), "topics_url": request.build_absolute_uri( reverse("course_topics", kwargs={"course_id": course_key}) ), @@ -357,18 +408,23 @@ def _format_datetime(dt): "allow_anonymous_to_peers": course.allow_anonymous_to_peers, "user_roles": user_roles, "has_bulk_delete_privileges": can_take_action_on_spam(request.user, course_key), - "has_moderation_privileges": bool(user_roles & { - FORUM_ROLE_ADMINISTRATOR, - FORUM_ROLE_MODERATOR, - FORUM_ROLE_COMMUNITY_TA, - }), + "has_moderation_privileges": bool( + user_roles + & { + FORUM_ROLE_ADMINISTRATOR, + FORUM_ROLE_MODERATOR, + FORUM_ROLE_COMMUNITY_TA, + } + ), "is_group_ta": bool(user_roles & {FORUM_ROLE_GROUP_MODERATOR}), "is_user_admin": request.user.is_staff, "is_course_staff": is_course_staff, "is_course_admin": is_course_admin, "provider": course_config.provider_type, "enable_in_context": course_config.enable_in_context, - "group_at_subsection": course_config.plugin_configuration.get("group_at_subsection", False), + "group_at_subsection": course_config.plugin_configuration.get( + "group_at_subsection", False + ), "edit_reasons": [ {"code": reason_code, "label": label} for (reason_code, label) in EDIT_REASON_CODES.items() @@ -377,17 +433,24 @@ def _format_datetime(dt): {"code": reason_code, "label": label} for (reason_code, label) in CLOSE_REASON_CODES.items() ], - 'show_discussions': bool(discussion_tab and discussion_tab.is_enabled(course, request.user)), - 'is_notify_all_learners_enabled': can_user_notify_all_learners( + "show_discussions": bool( + discussion_tab and discussion_tab.is_enabled(course, request.user) + ), + "is_notify_all_learners_enabled": can_user_notify_all_learners( user_roles, is_course_staff, is_course_admin ), - 'captcha_settings': { - 'enabled': is_captcha_enabled(course_key), - 'site_key': get_captcha_site_key_by_platform('web'), + "captcha_settings": { + "enabled": is_captcha_enabled(course_key), + "site_key": get_captcha_site_key_by_platform("web"), }, "is_email_verified": request.user.is_active, - "only_verified_users_can_post": ONLY_VERIFIED_USERS_CAN_POST.is_enabled(course_key), - "content_creation_rate_limited": is_content_creation_rate_limited(request, course_key, increment=False), + "only_verified_users_can_post": ONLY_VERIFIED_USERS_CAN_POST.is_enabled( + course_key + ), + "content_creation_rate_limited": is_content_creation_rate_limited( + request, course_key, increment=False + ), + "enable_discussion_ban": ENABLE_DISCUSSION_BAN.is_enabled(course_key), } @@ -440,7 +503,7 @@ def convert(text): return text def alphanum_key(key): - return [convert(c) for c in re.split('([0-9]+)', key)] + return [convert(c) for c in re.split("([0-9]+)", key)] return sorted(category_list, key=alphanum_key) @@ -482,7 +545,7 @@ def get_non_courseware_topics( course_key: CourseKey, course: CourseBlock, topic_ids: Optional[List[str]], - thread_counts: Dict[str, Dict[str, int]] + thread_counts: Dict[str, Dict[str, int]], ) -> Tuple[List[Dict], Set[str]]: """ Returns a list of topic trees that are not linked to courseware. @@ -506,13 +569,17 @@ def get_non_courseware_topics( existing_topic_ids = set() topics = list(course.discussion_topics.items()) for name, entry in topics: - if not topic_ids or entry['id'] in topic_ids: + if not topic_ids or entry["id"] in topic_ids: discussion_topic = DiscussionTopic( - entry["id"], name, get_thread_list_url(request, course_key, [entry["id"]]), + entry["id"], + name, + get_thread_list_url(request, course_key, [entry["id"]]), None, - thread_counts.get(entry["id"]) + thread_counts.get(entry["id"]), + ) + non_courseware_topics.append( + DiscussionTopicSerializer(discussion_topic).data ) - non_courseware_topics.append(DiscussionTopicSerializer(discussion_topic).data) if topic_ids and entry["id"] in topic_ids: existing_topic_ids.add(entry["id"]) @@ -520,7 +587,9 @@ def get_non_courseware_topics( return non_courseware_topics, existing_topic_ids -def get_course_topics(request: Request, course_key: CourseKey, topic_ids: Optional[Set[str]] = None): +def get_course_topics( + request: Request, course_key: CourseKey, topic_ids: Optional[Set[str]] = None +): """ Returns the course topic listing for the given course and user; filtered by 'topic_ids' list if given. @@ -544,15 +613,25 @@ def get_course_topics(request: Request, course_key: CourseKey, topic_ids: Option courseware_topics, existing_courseware_topic_ids = get_courseware_topics( request, course_key, course, topic_ids, thread_counts ) - non_courseware_topics, existing_non_courseware_topic_ids = get_non_courseware_topics( - request, course_key, course, topic_ids, thread_counts, + non_courseware_topics, existing_non_courseware_topic_ids = ( + get_non_courseware_topics( + request, + course_key, + course, + topic_ids, + thread_counts, + ) ) if topic_ids: - not_found_topic_ids = topic_ids - (existing_courseware_topic_ids | existing_non_courseware_topic_ids) + not_found_topic_ids = topic_ids - ( + existing_courseware_topic_ids | existing_non_courseware_topic_ids + ) if not_found_topic_ids: raise DiscussionNotFoundError( - "Discussion not found for '{}'.".format(", ".join(str(id) for id in not_found_topic_ids)) + "Discussion not found for '{}'.".format( + ", ".join(str(id) for id in not_found_topic_ids) + ) ) return { @@ -567,17 +646,19 @@ def get_v2_non_courseware_topics_as_v1(request, course_key, topics): """ non_courseware_topics = [] for topic in topics: - if topic.get('usage_key', '') is None: - for key in ['usage_key', 'enabled_in_context']: + if topic.get("usage_key", "") is None: + for key in ["usage_key", "enabled_in_context"]: topic.pop(key) - topic.update({ - 'children': [], - 'thread_list_url': get_thread_list_url( - request, - course_key, - topic.get('id'), - ) - }) + topic.update( + { + "children": [], + "thread_list_url": get_thread_list_url( + request, + course_key, + topic.get("id"), + ), + } + ) non_courseware_topics.append(topic) return non_courseware_topics @@ -589,23 +670,25 @@ def get_v2_courseware_topics_as_v1(request, course_key, sequentials, topics): courseware_topics = [] for sequential in sequentials: children = [] - for child in sequential.get('children', []): + for child in sequential.get("children", []): for topic in topics: - if child == topic.get('usage_key'): - topic.update({ - 'children': [], - 'thread_list_url': get_thread_list_url( - request, - course_key, - [topic.get('id')], - ) - }) - topic.pop('enabled_in_context') + if child == topic.get("usage_key"): + topic.update( + { + "children": [], + "thread_list_url": get_thread_list_url( + request, + course_key, + [topic.get("id")], + ), + } + ) + topic.pop("enabled_in_context") children.append(AttributeDict(topic)) discussion_topic = DiscussionTopic( None, - sequential.get('display_name'), + sequential.get("display_name"), get_thread_list_url( request, course_key, @@ -618,7 +701,7 @@ def get_v2_courseware_topics_as_v1(request, course_key, sequentials, topics): courseware_topics = [ courseware_topic for courseware_topic in courseware_topics - if courseware_topic.get('children', []) + if courseware_topic.get("children", []) ] return courseware_topics @@ -635,20 +718,21 @@ def get_v2_course_topics_as_v1( blocks_params = create_blocks_params(course_usage_key, request.user) blocks = get_blocks( request, - blocks_params['usage_key'], - blocks_params['user'], - blocks_params['depth'], - blocks_params['nav_depth'], - blocks_params['requested_fields'], - blocks_params['block_counts'], - blocks_params['student_view_data'], - blocks_params['return_type'], - blocks_params['block_types_filter'], + blocks_params["usage_key"], + blocks_params["user"], + blocks_params["depth"], + blocks_params["nav_depth"], + blocks_params["requested_fields"], + blocks_params["block_counts"], + blocks_params["student_view_data"], + blocks_params["return_type"], + blocks_params["block_types_filter"], hide_access_denials=False, - )['blocks'] + )["blocks"] - sequentials = [value for _, value in blocks.items() - if value.get('type') == "sequential"] + sequentials = [ + value for _, value in blocks.items() if value.get("type") == "sequential" + ] topics = get_course_topics_v2(course_key, request.user, topic_ids) non_courseware_topics = get_v2_non_courseware_topics_as_v1( @@ -705,24 +789,29 @@ def get_course_topics_v2( # Check access to the course store = modulestore() _get_course(course_key, user=user, check_tab=False) - user_is_privileged = user.is_staff or user.roles.filter( - course_id=course_key, - name__in=[ - FORUM_ROLE_MODERATOR, - FORUM_ROLE_COMMUNITY_TA, - FORUM_ROLE_ADMINISTRATOR, - ] - ).exists() + user_is_privileged = ( + user.is_staff + or user.roles.filter( + course_id=course_key, + name__in=[ + FORUM_ROLE_MODERATOR, + FORUM_ROLE_COMMUNITY_TA, + FORUM_ROLE_ADMINISTRATOR, + ], + ).exists() + ) with store.branch_setting(ModuleStoreEnum.Branch.draft_preferred, course_key): blocks = store.get_items( course_key, - qualifiers={'category': 'vertical'}, - fields=['usage_key', 'discussion_enabled', 'display_name'], + qualifiers={"category": "vertical"}, + fields=["usage_key", "discussion_enabled", "display_name"], ) accessible_vertical_keys = [] for block in blocks: - if block.discussion_enabled and (not block.visible_to_staff_only or user_is_privileged): + if block.discussion_enabled and ( + not block.visible_to_staff_only or user_is_privileged + ): accessible_vertical_keys.append(block.usage_key) accessible_vertical_keys.append(None) @@ -732,9 +821,13 @@ def get_course_topics_v2( ) if user_is_privileged: - topics_query = topics_query.filter(Q(usage_key__in=accessible_vertical_keys) | Q(enabled_in_context=False)) + topics_query = topics_query.filter( + Q(usage_key__in=accessible_vertical_keys) | Q(enabled_in_context=False) + ) else: - topics_query = topics_query.filter(usage_key__in=accessible_vertical_keys, enabled_in_context=True) + topics_query = topics_query.filter( + usage_key__in=accessible_vertical_keys, enabled_in_context=True + ) if topic_ids: topics_query = topics_query.filter(external_id__in=topic_ids) @@ -746,11 +839,13 @@ def get_course_topics_v2( reverse=True, ) elif order_by == TopicOrdering.NAME: - topics_query = topics_query.order_by('title') + topics_query = topics_query.order_by("title") else: - topics_query = topics_query.order_by('ordering') + topics_query = topics_query.order_by("ordering") - topics_data = DiscussionTopicSerializerV2(topics_query, many=True, context={"thread_counts": thread_counts}).data + topics_data = DiscussionTopicSerializerV2( + topics_query, many=True, context={"thread_counts": thread_counts} + ).data return [ topic_data for topic_data in topics_data @@ -777,7 +872,7 @@ def _get_user_profile_dict(request, usernames): else: username_list = [] user_profile_details = get_account_settings(request, username_list) - return {user['username']: user for user in user_profile_details} + return {user["username"]: user for user in user_profile_details} def _user_profile(user_profile): @@ -785,11 +880,7 @@ def _user_profile(user_profile): Returns the user profile object. For now, this just comprises the profile_image details. """ - return { - 'profile': { - 'image': user_profile['profile_image'] - } - } + return {"profile": {"image": user_profile["profile_image"]}} def _get_users(discussion_entity_type, discussion_entity, username_profile_dict): @@ -807,22 +898,28 @@ def _get_users(discussion_entity_type, discussion_entity, username_profile_dict) A dict of users with username as key and user profile details as value. """ users = {} - if discussion_entity['author']: - user_profile = username_profile_dict.get(discussion_entity['author']) + if discussion_entity["author"]: + user_profile = username_profile_dict.get(discussion_entity["author"]) if user_profile: - users[discussion_entity['author']] = _user_profile(user_profile) + users[discussion_entity["author"]] = _user_profile(user_profile) if ( discussion_entity_type == DiscussionEntity.comment - and discussion_entity['endorsed'] - and discussion_entity['endorsed_by'] + and discussion_entity["endorsed"] + and discussion_entity["endorsed_by"] ): - users[discussion_entity['endorsed_by']] = _user_profile(username_profile_dict[discussion_entity['endorsed_by']]) + users[discussion_entity["endorsed_by"]] = _user_profile( + username_profile_dict[discussion_entity["endorsed_by"]] + ) return users def _add_additional_response_fields( - request, serialized_discussion_entities, usernames, discussion_entity_type, include_profile_image + request, + serialized_discussion_entities, + usernames, + discussion_entity_type, + include_profile_image, ): """ Adds additional data to serialized discussion thread/comment. @@ -840,9 +937,13 @@ def _add_additional_response_fields( A list of serialized discussion thread/comment with additional data if requested. """ if include_profile_image: - username_profile_dict = _get_user_profile_dict(request, usernames=','.join(usernames)) + username_profile_dict = _get_user_profile_dict( + request, usernames=",".join(usernames) + ) for discussion_entity in serialized_discussion_entities: - discussion_entity['users'] = _get_users(discussion_entity_type, discussion_entity, username_profile_dict) + discussion_entity["users"] = _get_users( + discussion_entity_type, discussion_entity, username_profile_dict + ) return serialized_discussion_entities @@ -851,10 +952,12 @@ def _include_profile_image(requested_fields): """ Returns True if requested_fields list has 'profile_image' entity else False """ - return requested_fields and 'profile_image' in requested_fields + return requested_fields and "profile_image" in requested_fields -def _serialize_discussion_entities(request, context, discussion_entities, requested_fields, discussion_entity_type): +def _serialize_discussion_entities( + request, context, discussion_entities, requested_fields, discussion_entity_type +): """ It serializes Discussion Entity (Thread or Comment) and add additional data if requested. @@ -885,14 +988,19 @@ def _serialize_discussion_entities(request, context, discussion_entities, reques results.append(serialized_entity) if include_profile_image: - if serialized_entity['author'] and serialized_entity['author'] not in usernames: - usernames.append(serialized_entity['author']) if ( - 'endorsed' in serialized_entity and serialized_entity['endorsed'] and - 'endorsed_by' in serialized_entity and - serialized_entity['endorsed_by'] and serialized_entity['endorsed_by'] not in usernames + serialized_entity["author"] + and serialized_entity["author"] not in usernames + ): + usernames.append(serialized_entity["author"]) + if ( + "endorsed" in serialized_entity + and serialized_entity["endorsed"] + and "endorsed_by" in serialized_entity + and serialized_entity["endorsed_by"] + and serialized_entity["endorsed_by"] not in usernames ): - usernames.append(serialized_entity['endorsed_by']) + usernames.append(serialized_entity["endorsed_by"]) results = _add_additional_response_fields( request, results, usernames, discussion_entity_type, include_profile_image @@ -916,6 +1024,7 @@ def get_thread_list( order_direction: Literal["desc"] = "desc", requested_fields: Optional[List[Literal["profile_image"]]] = None, count_flagged: bool = None, + show_deleted: bool = False, ): """ Return the list of all discussion threads pertaining to the given course @@ -959,20 +1068,31 @@ def get_thread_list( CourseNotFoundError: if the requesting user does not have access to the requested course PageNotFoundError: if page requested is beyond the last """ - exclusive_param_count = sum(1 for param in [topic_id_list, text_search, following] if param) + exclusive_param_count = sum( + 1 for param in [topic_id_list, text_search, following] if param + ) if exclusive_param_count > 1: # pragma: no cover - raise ValueError("More than one mutually exclusive param passed to get_thread_list") + raise ValueError( + "More than one mutually exclusive param passed to get_thread_list" + ) - cc_map = {"last_activity_at": "activity", "comment_count": "comments", "vote_count": "votes"} + cc_map = { + "last_activity_at": "activity", + "comment_count": "comments", + "vote_count": "votes", + } if order_by not in cc_map: - raise ValidationError({ - "order_by": - [f"Invalid value. '{order_by}' must be 'last_activity_at', 'comment_count', or 'vote_count'"] - }) + raise ValidationError( + { + "order_by": [ + f"Invalid value. '{order_by}' must be 'last_activity_at', 'comment_count', or 'vote_count'" + ] + } + ) if order_direction != "desc": - raise ValidationError({ - "order_direction": [f"Invalid value. '{order_direction}' must be 'desc'"] - }) + raise ValidationError( + {"order_direction": [f"Invalid value. '{order_direction}' must be 'desc'"]} + ) course = _get_course(course_key, request.user) context = get_context(course, request) @@ -984,13 +1104,21 @@ def get_thread_list( except User.DoesNotExist: # Raising an error for a missing user leaks the presence of a username, # so just return an empty response. - return DiscussionAPIPagination(request, 0, 1).get_paginated_response({ - "results": [], - "text_search_rewrite": None, - }) + return DiscussionAPIPagination(request, 0, 1).get_paginated_response( + { + "results": [], + "text_search_rewrite": None, + } + ) if count_flagged and not context["has_moderation_privilege"]: - raise PermissionDenied("`count_flagged` can only be set by users with moderator access or higher.") + raise PermissionDenied( + "`count_flagged` can only be set by users with moderator access or higher." + ) + if show_deleted and not context["has_moderation_privilege"]: + raise PermissionDenied( + "`show_deleted` can only be set by users with moderator access or higher." + ) group_id = None allowed_roles = [ @@ -1010,7 +1138,9 @@ def get_thread_list( not context["has_moderation_privilege"] or request.user.id in context["ta_user_ids"] ): - group_id = get_group_id_for_user(request.user, CourseDiscussionSettings.get(course.id)) + group_id = get_group_id_for_user( + request.user, CourseDiscussionSettings.get(course.id) + ) query_params = { "user_id": str(request.user.id), @@ -1023,21 +1153,24 @@ def get_thread_list( "flagged": flagged, "thread_type": thread_type, "count_flagged": count_flagged, + "show_deleted": show_deleted, } if view: if view in ["unread", "unanswered", "unresponded"]: query_params[view] = "true" else: - raise ValidationError({ - "view": [f"Invalid value. '{view}' must be 'unread' or 'unanswered'"] - }) + raise ValidationError( + {"view": [f"Invalid value. '{view}' must be 'unread' or 'unanswered'"]} + ) if following: paginated_results = context["cc_requester"].subscribed_threads(query_params) else: query_params["course_id"] = str(course.id) - query_params["commentable_ids"] = ",".join(topic_id_list) if topic_id_list else None + query_params["commentable_ids"] = ( + ",".join(topic_id_list) if topic_id_list else None + ) query_params["text"] = text_search paginated_results = Thread.search(query_params) # The comments service returns the last page of results if the requested @@ -1047,19 +1180,25 @@ def get_thread_list( raise PageNotFoundError("Page not found (No results on this page).") results = _serialize_discussion_entities( - request, context, paginated_results.collection, requested_fields, DiscussionEntity.thread + request, + context, + paginated_results.collection, + requested_fields, + DiscussionEntity.thread, ) paginator = DiscussionAPIPagination( request, paginated_results.page, paginated_results.num_pages, - paginated_results.thread_count + paginated_results.thread_count, + ) + return paginator.get_paginated_response( + { + "results": results, + "text_search_rewrite": paginated_results.corrected_text, + } ) - return paginator.get_paginated_response({ - "results": results, - "text_search_rewrite": paginated_results.corrected_text, - }) def get_learner_active_thread_list(request, course_key, query_params): @@ -1154,49 +1293,101 @@ def get_learner_active_thread_list(request, course_key, query_params): course = _get_course(course_key, request.user) context = get_context(course, request) - group_id = query_params.get('group_id', None) - user_id = query_params.get('user_id', None) - count_flagged = query_params.get('count_flagged', None) + group_id = query_params.get("group_id", None) + user_id = query_params.get("user_id", None) + count_flagged = query_params.get("count_flagged", None) + show_deleted = query_params.get("show_deleted", False) + if isinstance(show_deleted, str): + show_deleted = show_deleted.lower() == "true" + if user_id is None: - return Response({'detail': 'Invalid user id'}, status=status.HTTP_400_BAD_REQUEST) + return Response( + {"detail": "Invalid user id"}, status=status.HTTP_400_BAD_REQUEST + ) if count_flagged and not context["has_moderation_privilege"]: - raise PermissionDenied("count_flagged can only be set by users with moderation roles.") + raise PermissionDenied( + "count_flagged can only be set by users with moderation roles." + ) if "flagged" in query_params.keys() and not context["has_moderation_privilege"]: raise PermissionDenied("Flagged filter is only available for moderators") + if show_deleted and not context["has_moderation_privilege"]: + raise PermissionDenied( + "show_deleted can only be set by users with moderation roles." + ) if group_id is None: comment_client_user = comment_client.User(id=user_id, course_id=course_key) else: - comment_client_user = comment_client.User(id=user_id, course_id=course_key, group_id=group_id) + comment_client_user = comment_client.User( + id=user_id, course_id=course_key, group_id=group_id + ) try: threads, page, num_pages = comment_client_user.active_threads(query_params) threads = set_attribute(threads, "pinned", False) + + # This portion below is temporary until we migrate to forum v2 + filtered_threads = [] + for thread in threads: + try: + forum_thread = forum_api.get_thread( + thread.get("id"), course_id=str(course_key) + ) + is_deleted = forum_thread.get("is_deleted", False) + + if show_deleted and is_deleted: + thread["is_deleted"] = True + thread["deleted_at"] = forum_thread.get("deleted_at") + thread["deleted_by"] = forum_thread.get("deleted_by") + filtered_threads.append(thread) + elif not show_deleted and not is_deleted: + filtered_threads.append(thread) + except Exception as e: # pylint: disable=broad-exception-caught + log.warning( + "Failed to check thread %s deletion status: %s", thread.get("id"), e + ) + if not show_deleted: # Fail safe: include thread for regular users + filtered_threads.append(thread) + results = _serialize_discussion_entities( - request, context, threads, {'profile_image'}, DiscussionEntity.thread + request, + context, + filtered_threads, + {"profile_image"}, + DiscussionEntity.thread, ) paginator = DiscussionAPIPagination( - request, - page, - num_pages, - len(threads) + request, page, num_pages, len(filtered_threads) + ) + return paginator.get_paginated_response( + { + "results": results, + } ) - return paginator.get_paginated_response({ - "results": results, - }) except CommentClient500Error: return DiscussionAPIPagination( request, page_num=1, num_pages=0, - ).get_paginated_response({ - "results": [], - }) + ).get_paginated_response( + { + "results": [], + } + ) -def get_comment_list(request, thread_id, endorsed, page, page_size, flagged=False, requested_fields=None, - merge_question_type_responses=False): +def get_comment_list( + request, + thread_id, + endorsed, + page, + page_size, + flagged=False, + requested_fields=None, + merge_question_type_responses=False, + show_deleted=False, +): """ Return the list of comments in the given thread. @@ -1226,7 +1417,7 @@ def get_comment_list(request, thread_id, endorsed, page, page_size, flagged=Fals discussion.rest_api.views.CommentViewSet for more detail. """ response_skip = page_size * (page - 1) - reverse_order = request.GET.get('reverse_order', False) + reverse_order = request.GET.get("reverse_order", False) from_mfe_sidebar = request.GET.get("enable_in_context_sidebar", False) cc_thread, context = _get_thread_and_context( request, @@ -1239,19 +1430,23 @@ def get_comment_list(request, thread_id, endorsed, page, page_size, flagged=Fals "response_skip": response_skip, "response_limit": page_size, "reverse_order": reverse_order, - "merge_question_type_responses": merge_question_type_responses - } + "merge_question_type_responses": merge_question_type_responses, + }, ) # Responses to discussion threads cannot be separated by endorsed, but # responses to question threads must be separated by endorsed due to the # existing comments service interface if cc_thread["thread_type"] == "question" and not merge_question_type_responses: if endorsed is None: # lint-amnesty, pylint: disable=no-else-raise - raise ValidationError({"endorsed": ["This field is required for question threads."]}) + raise ValidationError( + {"endorsed": ["This field is required for question threads."]} + ) elif endorsed: # CS does not apply resp_skip and resp_limit to endorsed responses # of a question post - responses = cc_thread["endorsed_responses"][response_skip:(response_skip + page_size)] + responses = cc_thread["endorsed_responses"][ + response_skip: (response_skip + page_size) + ] resp_total = len(cc_thread["endorsed_responses"]) else: responses = cc_thread["non_endorsed_responses"] @@ -1260,7 +1455,11 @@ def get_comment_list(request, thread_id, endorsed, page, page_size, flagged=Fals if not merge_question_type_responses: if endorsed is not None: raise ValidationError( - {"endorsed": ["This field may not be specified for discussion threads."]} + { + "endorsed": [ + "This field may not be specified for discussion threads." + ] + } ) responses = cc_thread["children"] resp_total = cc_thread["resp_total"] @@ -1272,9 +1471,21 @@ def get_comment_list(request, thread_id, endorsed, page, page_size, flagged=Fals raise PageNotFoundError("Page not found (No results on this page).") num_pages = (resp_total + page_size - 1) // page_size if resp_total else 1 - results = _serialize_discussion_entities(request, context, responses, requested_fields, DiscussionEntity.comment) + if not show_deleted: + responses = [ + response for response in responses if not response.get("is_deleted", False) + ] + else: + if not context["has_moderation_privilege"]: + raise PermissionDenied( + "`show_deleted` can only be set by users with moderation roles." + ) + + results = _serialize_discussion_entities( + request, context, responses, requested_fields, DiscussionEntity.comment + ) - paginator = DiscussionAPIPagination(request, page, num_pages, resp_total) + paginator = DiscussionAPIPagination(request, page, num_pages, len(responses)) track_thread_viewed_event(request, context["course"], cc_thread, from_mfe_sidebar) return paginator.get_paginated_response(results) @@ -1292,7 +1503,9 @@ def _check_fields(allowed_fields, data, message): ValidationError if the given data contains a key that is not in allowed_fields """ - non_allowed_fields = {field: [message] for field in data.keys() if field not in allowed_fields} + non_allowed_fields = { + field: [message] for field in data.keys() if field not in allowed_fields + } if non_allowed_fields: raise ValidationError(non_allowed_fields) @@ -1314,7 +1527,7 @@ def _check_initializable_thread_fields(data, context): _check_fields( get_initializable_thread_fields(context), data, - "This field is not initializable." + "This field is not initializable.", ) @@ -1335,7 +1548,7 @@ def _check_initializable_comment_fields(data, context): _check_fields( get_initializable_comment_fields(context), data, - "This field is not initializable." + "This field is not initializable.", ) @@ -1345,28 +1558,40 @@ def _check_editable_fields(cc_content, data, context): editable by the requesting user """ _check_fields( - get_editable_fields(cc_content, context), - data, - "This field is not editable." + get_editable_fields(cc_content, context), data, "This field is not editable." ) -def _do_extra_actions(api_content, cc_content, request_fields, actions_form, context, request): +def _do_extra_actions( + api_content, cc_content, request_fields, actions_form, context, request +): """ Perform any necessary additional actions related to content creation or update that require a separate comments service request. """ for field, form_value in actions_form.cleaned_data.items(): - if field in request_fields and field in api_content and form_value != api_content[field]: + if ( + field in request_fields + and field in api_content + and form_value != api_content[field] + ): api_content[field] = form_value if field == "following": - _handle_following_field(form_value, context["cc_requester"], cc_content, request) + _handle_following_field( + form_value, context["cc_requester"], cc_content, request + ) elif field == "abuse_flagged": - _handle_abuse_flagged_field(form_value, context["cc_requester"], cc_content, request) + _handle_abuse_flagged_field( + form_value, context["cc_requester"], cc_content, request + ) elif field == "voted": - _handle_voted_field(form_value, cc_content, api_content, request, context) + _handle_voted_field( + form_value, cc_content, api_content, request, context + ) elif field == "read": - _handle_read_field(api_content, form_value, context["cc_requester"], cc_content) + _handle_read_field( + api_content, form_value, context["cc_requester"], cc_content + ) elif field == "pinned": _handle_pinned_field(form_value, cc_content, context["cc_requester"]) else: @@ -1376,7 +1601,7 @@ def _do_extra_actions(api_content, cc_content, request_fields, actions_form, con def _handle_following_field(form_value, user, cc_content, request): """follow/unfollow thread for the user""" course_key = CourseKey.from_string(cc_content.course_id) - course = get_course_with_access(request.user, 'load', course_key) + course = get_course_with_access(request.user, "load", course_key) if form_value: user.follow(cc_content) else: @@ -1389,15 +1614,19 @@ def _handle_following_field(form_value, user, cc_content, request): def _handle_abuse_flagged_field(form_value, user, cc_content, request): """mark or unmark thread/comment as abused""" course_key = CourseKey.from_string(cc_content.course_id) - course = get_course_with_access(request.user, 'load', course_key) + course = get_course_with_access(request.user, "load", course_key) if form_value: cc_content.flagAbuse(user, cc_content) track_discussion_reported_event(request, course, cc_content) if ENABLE_DISCUSSIONS_MFE.is_enabled(course_key): - if cc_content.type == 'thread': - thread_flagged.send(sender='flag_abuse_for_thread', user=user, post=cc_content) + if cc_content.type == "thread": + thread_flagged.send( + sender="flag_abuse_for_thread", user=user, post=cc_content + ) else: - comment_flagged.send(sender='flag_abuse_for_comment', user=user, post=cc_content) + comment_flagged.send( + sender="flag_abuse_for_comment", user=user, post=cc_content + ) else: remove_all = bool(is_privileged_user(course_key, User.objects.get(id=user.id))) cc_content.unFlagAbuse(user, cc_content, remove_all) @@ -1406,7 +1635,7 @@ def _handle_abuse_flagged_field(form_value, user, cc_content, request): def _handle_voted_field(form_value, cc_content, api_content, request, context): """vote or undo vote on thread/comment""" - signal = thread_voted if cc_content.type == 'thread' else comment_voted + signal = thread_voted if cc_content.type == "thread" else comment_voted signal.send(sender=None, user=context["request"].user, post=cc_content) if form_value: context["cc_requester"].vote(cc_content, "up") @@ -1415,7 +1644,11 @@ def _handle_voted_field(form_value, cc_content, api_content, request, context): context["cc_requester"].unvote(cc_content) api_content["vote_count"] -= 1 track_voted_event( - request, context["course"], cc_content, vote_value="up", undo_vote=not form_value + request, + context["course"], + cc_content, + vote_value="up", + undo_vote=not form_value, ) @@ -1423,7 +1656,7 @@ def _handle_read_field(api_content, form_value, user, cc_content): """ Marks thread as read for the user """ - if form_value and not cc_content['read']: + if form_value and not cc_content["read"]: user.read(cc_content) # When a thread is marked as read, all of its responses and comments # are also marked as read. @@ -1485,29 +1718,56 @@ def create_thread(request, thread_data): if not discussion_open_for_user(course, user): raise DiscussionBlackOutException + # Check if user is banned from discussions + is_user_banned_func = getattr(forum_api, 'is_user_banned', None) + user_banned = False + if ENABLE_DISCUSSION_BAN.is_enabled(course_key) and is_user_banned_func: + try: + user_banned = is_user_banned_func(user, course_key) + except (CommentClientRequestError, CommentClient500Error) as exc: + log.warning( + "Error while checking discussion ban status for user %s in course %s: %s", + getattr(user, "id", None), + course_key, + exc, + ) + if user_banned: + raise PermissionDenied("You are banned from posting in this course's discussions.") + notify_all_learners = thread_data.pop("notify_all_learners", False) context = get_context(course, request) _check_initializable_thread_fields(thread_data, context) discussion_settings = CourseDiscussionSettings.get(course_key) - if ( - "group_id" not in thread_data and - is_commentable_divided(course_key, thread_data.get("topic_id"), discussion_settings) + if "group_id" not in thread_data and is_commentable_divided( + course_key, thread_data.get("topic_id"), discussion_settings ): thread_data = thread_data.copy() thread_data["group_id"] = get_group_id_for_user(user, discussion_settings) serializer = ThreadSerializer(data=thread_data, context=context) actions_form = ThreadActionsForm(thread_data) if not (serializer.is_valid() and actions_form.is_valid()): - raise ValidationError(dict(list(serializer.errors.items()) + list(actions_form.errors.items()))) + raise ValidationError( + dict(list(serializer.errors.items()) + list(actions_form.errors.items())) + ) serializer.save() cc_thread = serializer.instance - thread_created.send(sender=None, user=user, post=cc_thread, notify_all_learners=notify_all_learners) + thread_created.send( + sender=None, user=user, post=cc_thread, notify_all_learners=notify_all_learners + ) api_thread = serializer.data - _do_extra_actions(api_thread, cc_thread, list(thread_data.keys()), actions_form, context, request) + _do_extra_actions( + api_thread, cc_thread, list(thread_data.keys()), actions_form, context, request + ) - track_thread_created_event(request, course, cc_thread, actions_form.cleaned_data["following"], - from_mfe_sidebar, notify_all_learners) + track_thread_created_event( + request, + course, + cc_thread, + actions_form.cleaned_data["following"], + from_mfe_sidebar, + notify_all_learners, + ) return api_thread @@ -1538,6 +1798,22 @@ def create_comment(request, comment_data): if not discussion_open_for_user(course, request.user): raise DiscussionBlackOutException + # Check if user is banned from discussions + is_user_banned_func = getattr(forum_api, 'is_user_banned', None) + user_banned = False + if ENABLE_DISCUSSION_BAN.is_enabled(course.id) and is_user_banned_func: + try: + user_banned = is_user_banned_func(request.user, course.id) + except (CommentClientRequestError, CommentClient500Error) as exc: + log.warning( + "Error while checking discussion ban status for user %s in course %s: %s", + getattr(request.user, "id", None), + course.id, + exc, + ) + if user_banned: + raise PermissionDenied("You are banned from posting in this course's discussions.") + # if a thread is closed; no new comments could be made to it if cc_thread["closed"]: raise PermissionDenied @@ -1546,15 +1822,30 @@ def create_comment(request, comment_data): serializer = CommentSerializer(data=comment_data, context=context) actions_form = CommentActionsForm(comment_data) if not (serializer.is_valid() and actions_form.is_valid()): - raise ValidationError(dict(list(serializer.errors.items()) + list(actions_form.errors.items()))) + raise ValidationError( + dict(list(serializer.errors.items()) + list(actions_form.errors.items())) + ) context["cc_requester"].follow(cc_thread) serializer.save() cc_comment = serializer.instance comment_created.send(sender=None, user=request.user, post=cc_comment) api_comment = serializer.data - _do_extra_actions(api_comment, cc_comment, list(comment_data.keys()), actions_form, context, request) - track_comment_created_event(request, course, cc_comment, cc_thread["commentable_id"], followed=False, - from_mfe_sidebar=from_mfe_sidebar) + _do_extra_actions( + api_comment, + cc_comment, + list(comment_data.keys()), + actions_form, + context, + request, + ) + track_comment_created_event( + request, + course, + cc_comment, + cc_thread["commentable_id"], + followed=False, + from_mfe_sidebar=from_mfe_sidebar, + ) return api_comment @@ -1576,24 +1867,32 @@ def update_thread(request, thread_id, update_data): The updated thread; see discussion.rest_api.views.ThreadViewSet for more detail. """ - cc_thread, context = _get_thread_and_context(request, thread_id, retrieve_kwargs={"with_responses": True}) + cc_thread, context = _get_thread_and_context( + request, thread_id, retrieve_kwargs={"with_responses": True} + ) _check_editable_fields(cc_thread, update_data, context) - serializer = ThreadSerializer(cc_thread, data=update_data, partial=True, context=context) + serializer = ThreadSerializer( + cc_thread, data=update_data, partial=True, context=context + ) actions_form = ThreadActionsForm(update_data) if not (serializer.is_valid() and actions_form.is_valid()): - raise ValidationError(dict(list(serializer.errors.items()) + list(actions_form.errors.items()))) + raise ValidationError( + dict(list(serializer.errors.items()) + list(actions_form.errors.items())) + ) # Only save thread object if some of the edited fields are in the thread data, not extra actions if set(update_data) - set(actions_form.fields): serializer.save() # signal to update Teams when a user edits a thread thread_edited.send(sender=None, user=request.user, post=cc_thread) api_thread = serializer.data - _do_extra_actions(api_thread, cc_thread, list(update_data.keys()), actions_form, context, request) + _do_extra_actions( + api_thread, cc_thread, list(update_data.keys()), actions_form, context, request + ) # always return read as True (and therefore unread_comment_count=0) as reasonably # accurate shortcut, rather than adding additional processing. - api_thread['read'] = True - api_thread['unread_comment_count'] = 0 + api_thread["read"] = True + api_thread["unread_comment_count"] = 0 return api_thread @@ -1628,16 +1927,27 @@ def update_comment(request, comment_id, update_data): """ cc_comment, context = _get_comment_and_context(request, comment_id) _check_editable_fields(cc_comment, update_data, context) - serializer = CommentSerializer(cc_comment, data=update_data, partial=True, context=context) + serializer = CommentSerializer( + cc_comment, data=update_data, partial=True, context=context + ) actions_form = CommentActionsForm(update_data) if not (serializer.is_valid() and actions_form.is_valid()): - raise ValidationError(dict(list(serializer.errors.items()) + list(actions_form.errors.items()))) + raise ValidationError( + dict(list(serializer.errors.items()) + list(actions_form.errors.items())) + ) # Only save comment object if some of the edited fields are in the comment data, not extra actions if set(update_data) - set(actions_form.fields): serializer.save() comment_edited.send(sender=None, user=request.user, post=cc_comment) api_comment = serializer.data - _do_extra_actions(api_comment, cc_comment, list(update_data.keys()), actions_form, context, request) + _do_extra_actions( + api_comment, + cc_comment, + list(update_data.keys()), + actions_form, + context, + request, + ) _handle_comment_signals(update_data, cc_comment, request.user) return api_comment @@ -1671,7 +1981,9 @@ def get_thread(request, thread_id, requested_fields=None, course_id=None): ) if course_id and course_id != cc_thread.course_id: raise ThreadNotFoundError("Thread not found.") - return _serialize_discussion_entities(request, context, [cc_thread], requested_fields, DiscussionEntity.thread)[0] + return _serialize_discussion_entities( + request, context, [cc_thread], requested_fields, DiscussionEntity.thread + )[0] def get_response_comments(request, comment_id, page, page_size, requested_fields=None): @@ -1699,7 +2011,10 @@ def get_response_comments(request, comment_id, page, page_size, requested_fields """ try: cc_comment = Comment(id=comment_id).retrieve() - reverse_order = request.GET.get('reverse_order', False) + reverse_order = request.GET.get("reverse_order", False) + show_deleted = request.GET.get("show_deleted", False) + show_deleted = show_deleted in ["true", "True", True] + cc_thread, context = _get_thread_and_context( request, cc_comment["thread_id"], @@ -1707,10 +2022,13 @@ def get_response_comments(request, comment_id, page, page_size, requested_fields "with_responses": True, "recursive": True, "reverse_order": reverse_order, - } + "show_deleted": show_deleted, + }, ) if cc_thread["thread_type"] == "question": - thread_responses = itertools.chain(cc_thread["endorsed_responses"], cc_thread["non_endorsed_responses"]) + thread_responses = itertools.chain( + cc_thread["endorsed_responses"], cc_thread["non_endorsed_responses"] + ) else: thread_responses = cc_thread["children"] response_comments = [] @@ -1720,16 +2038,35 @@ def get_response_comments(request, comment_id, page, page_size, requested_fields break response_skip = page_size * (page - 1) - paged_response_comments = response_comments[response_skip:(response_skip + page_size)] + paged_response_comments = response_comments[ + response_skip: (response_skip + page_size) + ] if not paged_response_comments and page != 1: raise PageNotFoundError("Page not found (No results on this page).") + if not show_deleted: + paged_response_comments = [ + response + for response in paged_response_comments + if not response.get("is_deleted", False) + ] + else: + if not context["has_moderation_privilege"]: + raise PermissionDenied( + "`show_deleted` can only be set by users with moderation roles." + ) results = _serialize_discussion_entities( - request, context, paged_response_comments, requested_fields, DiscussionEntity.comment + request, + context, + paged_response_comments, + requested_fields, + DiscussionEntity.comment, ) - comments_count = len(response_comments) - num_pages = (comments_count + page_size - 1) // page_size if comments_count else 1 + comments_count = len(paged_response_comments) + num_pages = ( + (comments_count + page_size - 1) // page_size if comments_count else 1 + ) paginator = DiscussionAPIPagination(request, page, num_pages, comments_count) return paginator.get_paginated_response(results) except CommentClientRequestError as err: @@ -1773,16 +2110,20 @@ def get_user_comments( context = get_context(course, request) if flagged and not context["has_moderation_privilege"]: - raise ValidationError("Only privileged users can filter comments by flagged status") + raise ValidationError( + "Only privileged users can filter comments by flagged status" + ) try: - response = Comment.retrieve_all({ - 'user_id': author.id, - 'course_id': str(course_key), - 'flagged': flagged, - 'page': page, - 'per_page': page_size, - }) + response = Comment.retrieve_all( + { + "user_id": author.id, + "course_id": str(course_key), + "flagged": flagged, + "page": page, + "per_page": page_size, + } + ) except CommentClientRequestError as err: raise CommentNotFoundError("Comment not found") from err @@ -1822,7 +2163,7 @@ def delete_thread(request, thread_id): """ cc_thread, context = _get_thread_and_context(request, thread_id) if can_delete(cc_thread, context): - cc_thread.delete() + cc_thread.delete(deleted_by=str(request.user.id)) thread_deleted.send(sender=None, user=request.user, post=cc_thread) track_thread_deleted_event(request, context["course"], cc_thread) else: @@ -1847,7 +2188,7 @@ def delete_comment(request, comment_id): """ cc_comment, context = _get_comment_and_context(request, comment_id) if can_delete(cc_comment, context): - cc_comment.delete() + cc_comment.delete(deleted_by=str(request.user.id)) comment_deleted.send(sender=None, user=request.user, post=cc_comment) track_comment_deleted_event(request, context["course"], cc_comment) else: @@ -1879,7 +2220,10 @@ def get_course_discussion_user_stats( """ course_key = CourseKey.from_string(course_key_str) - is_privileged = has_discussion_privileges(user=request.user, course_id=course_key) or request.user.is_staff + is_privileged = ( + has_discussion_privileges(user=request.user, course_id=course_key) + or request.user.is_staff + ) if is_privileged: order_by = order_by or UserOrdering.BY_FLAGS else: @@ -1888,33 +2232,65 @@ def get_course_discussion_user_stats( raise ValidationError({"order_by": "Invalid value"}) params = { - 'sort_key': str(order_by), - 'page': page, - 'per_page': page_size, + "sort_key": str(order_by), + "page": page, + "per_page": page_size, } comma_separated_usernames = matched_users_count = matched_users_pages = None if username_search_string: - comma_separated_usernames, matched_users_count, matched_users_pages = get_usernames_from_search_string( - course_key, username_search_string, page, page_size + comma_separated_usernames, matched_users_count, matched_users_pages = ( + get_usernames_from_search_string( + course_key, username_search_string, page, page_size + ) ) search_event_data = { - 'query': username_search_string, - 'search_type': 'Learner', - 'page': params.get('page'), - 'sort_key': params.get('sort_key'), - 'total_results': matched_users_count, + "query": username_search_string, + "search_type": "Learner", + "page": params.get("page"), + "sort_key": params.get("sort_key"), + "total_results": matched_users_count, } course = _get_course(course_key, request.user) track_forum_search_event(request, course, search_event_data) + if not comma_separated_usernames: - return DiscussionAPIPagination(request, 0, 1).get_paginated_response({ - "results": [], - }) + return DiscussionAPIPagination(request, 0, 1).get_paginated_response( + { + "results": [], + } + ) - params['usernames'] = comma_separated_usernames + params["usernames"] = comma_separated_usernames course_stats_response = get_course_user_stats(course_key, params) + # Exclude banned users from the learners list + # Get all active bans for this course using forum API + get_banned_usernames = getattr(forum_api, 'get_banned_usernames', None) + banned_usernames = [] + # Only filter banned users if feature flag is enabled + if ENABLE_DISCUSSION_BAN.is_enabled(course_key) and get_banned_usernames is not None: + try: + banned_usernames = get_banned_usernames( + course_id=course_key, + org_key=course_key.org + ) + except Exception: # pylint: disable=broad-except + log.exception( + "Error retrieving banned usernames for course %s; returning unfiltered discussion stats.", + course_key, + ) + banned_usernames = [] + + # Filter out banned users from the stats + if banned_usernames: + course_stats_response["user_stats"] = [ + stats for stats in course_stats_response["user_stats"] + if stats.get('username') not in banned_usernames + ] + # Update count to reflect filtered results + course_stats_response["count"] = len(course_stats_response["user_stats"]) + if comma_separated_usernames: updated_course_stats = add_stats_for_users_with_no_discussion_content( course_stats_response["user_stats"], @@ -1931,71 +2307,502 @@ def get_course_discussion_user_stats( paginator = DiscussionAPIPagination( request, course_stats_response["page"], - matched_users_pages if username_search_string else course_stats_response["num_pages"], - matched_users_count if username_search_string else course_stats_response["count"], + ( + matched_users_pages + if username_search_string + else course_stats_response["num_pages"] + ), + ( + matched_users_count + if username_search_string + else course_stats_response["count"] + ), + ) + return paginator.get_paginated_response( + { + "results": serializer.data, + } ) - return paginator.get_paginated_response({ - "results": serializer.data, - }) def get_users_without_stats( - username_search_string, - course_key, - page_number, - page_size, - request, - is_privileged + username_search_string, course_key, page_number, page_size, request, is_privileged ): """ This return users with no user stats. This function will be deprecated when this ticket DOS-3414 is resolved """ if username_search_string: - comma_separated_usernames, matched_users_count, matched_users_pages = get_usernames_from_search_string( - course_key, username_search_string, page_number, page_size + comma_separated_usernames, matched_users_count, matched_users_pages = ( + get_usernames_from_search_string( + course_key, username_search_string, page_number, page_size + ) ) if not comma_separated_usernames: - return DiscussionAPIPagination(request, 0, 1).get_paginated_response({ - "results": [], - }) + return DiscussionAPIPagination(request, 0, 1).get_paginated_response( + { + "results": [], + } + ) else: - comma_separated_usernames, matched_users_count, matched_users_pages = get_usernames_for_course( - course_key, page_number, page_size + comma_separated_usernames, matched_users_count, matched_users_pages = ( + get_usernames_for_course(course_key, page_number, page_size) ) if comma_separated_usernames: - updated_course_stats = add_stats_for_users_with_null_values([], comma_separated_usernames) + updated_course_stats = add_stats_for_users_with_null_values( + [], comma_separated_usernames + ) - serializer = UserStatsSerializer(updated_course_stats, context={"is_privileged": is_privileged}, many=True) + serializer = UserStatsSerializer( + updated_course_stats, context={"is_privileged": is_privileged}, many=True + ) paginator = DiscussionAPIPagination( request, page_number, matched_users_pages, matched_users_count, ) - return paginator.get_paginated_response({ - "results": serializer.data, - }) + return paginator.get_paginated_response( + { + "results": serializer.data, + } + ) def add_stats_for_users_with_null_values(course_stats, users_in_course): """ Update users stats for users with no discussion stats available in course """ - users_returned_from_api = [user['username'] for user in course_stats] - user_list = users_in_course.split(',') + users_returned_from_api = [user["username"] for user in course_stats] + user_list = users_in_course.split(",") users_with_no_discussion_content = set(user_list) ^ set(users_returned_from_api) updated_course_stats = course_stats for user in users_with_no_discussion_content: - updated_course_stats.append({ - 'username': user, - 'threads': None, - 'replies': None, - 'responses': None, - 'active_flags': None, - 'inactive_flags': None, - }) - updated_course_stats = sorted(updated_course_stats, key=lambda d: len(d['username'])) + updated_course_stats.append( + { + "username": user, + "threads": None, + "replies": None, + "responses": None, + "active_flags": None, + "inactive_flags": None, + } + ) + updated_course_stats = sorted( + updated_course_stats, key=lambda d: len(d["username"]) + ) return updated_course_stats + + +def _get_user_label_function(course_staff_user_ids, moderator_user_ids, ta_user_ids): + """ + Create and return a function that determines user labels based on role. + + Args: + course_staff_user_ids: List of user IDs for course staff + moderator_user_ids: List of user IDs for moderators + ta_user_ids: List of user IDs for TAs + + Returns: + A function that takes a user_id and returns the appropriate label or None + """ + + def get_user_label(user_id): + """Get role label for a user ID.""" + try: + user_id_int = int(user_id) + if user_id_int in course_staff_user_ids: + return "Staff" + elif user_id_int in moderator_user_ids: + return "Moderator" + elif user_id_int in ta_user_ids: + return "Community TA" + except (ValueError, TypeError): + # If user_id has any issues, there's no label to return + pass + return None + + return get_user_label + + +def _process_deleted_thread(thread_data, get_user_label_fn, usernames_set): + """ + Process a single deleted thread into the standardized content item format. + + Args: + thread_data: Raw thread data from forum API + get_user_label_fn: Function to get user labels by user ID + usernames_set: Set to collect usernames for profile image fetch (modified in-place) + + Returns: + dict: Formatted content item for the thread + """ + author_username = thread_data.get("author_username", "") or None + author_id = thread_data.get("author_id", "") + + # If author_username is missing or empty, try to get it from author_id + if not author_username and author_id: + try: + author_user = User.objects.get(id=int(author_id)) + author_username = author_user.username + except (User.DoesNotExist, ValueError): + # If user not found or invalid ID, use placeholder + author_username = None + + deleted_by_id = thread_data.get("deleted_by") + deleted_by_username = None + + # Get deleted_by username + if deleted_by_id: + try: + deleted_user = User.objects.get(id=int(deleted_by_id)) + deleted_by_username = deleted_user.username + usernames_set.add(deleted_by_username) + except (User.DoesNotExist, ValueError): + # If user not found or invalid ID, skip setting deleted fields + pass + + if author_username: + usernames_set.add(author_username) + + # Strip HTML tags from preview + body_text = thread_data.get("body", "") + preview_text = strip_tags(body_text)[:100] if body_text else "" + + thread_id = thread_data.get("_id", thread_data.get("id")) + + # Calculate vote information + votes = thread_data.get("votes", {}) + vote_count = votes.get("up_count", 0) if isinstance(votes, dict) else thread_data.get("vote_count", 0) + + # Get abuse flaggers + abuse_flaggers = thread_data.get("abuse_flaggers", []) + abuse_flagged_count = len(abuse_flaggers) if abuse_flaggers else None + + return { + "id": str(thread_id) + "-thread", + "type": "thread", + "title": thread_data.get("title", ""), + "raw_body": body_text, + "rendered_body": body_text, # For deleted content, just use raw body + "preview_body": preview_text, + "course_id": thread_data.get("course_id", ""), + "author": author_username, + "author_id": thread_data.get("author_id", ""), + "author_label": get_user_label_fn(thread_data.get("author_id")), + "topic_id": thread_data.get("commentable_id", ""), + "commentable_id": thread_data.get("commentable_id", ""), + "group_id": thread_data.get("group_id"), + "group_name": None, # Will be populated by API layer if needed + "created_at": thread_data.get("created_at"), + "updated_at": thread_data.get("updated_at"), + "thread_type": thread_data.get("thread_type", "discussion"), + "anonymous": thread_data.get("anonymous", False), + "anonymous_to_peers": thread_data.get("anonymous_to_peers", False), + "pinned": thread_data.get("pinned", False), + "closed": thread_data.get("closed", False), + "following": False, # Deleted content is not followable + "abuse_flagged": len(abuse_flaggers) > 0 if abuse_flaggers else False, + "abuse_flagged_count": abuse_flagged_count, + "voted": False, # Cannot vote on deleted content + "vote_count": vote_count, + "comment_count": thread_data.get("comment_count", 0), + "unread_comment_count": 0, # Deleted content has no unread count + "comment_list_url": None, + "endorsed_comment_list_url": None, + "non_endorsed_comment_list_url": None, + "read": True, # Treat deleted content as read + "has_endorsed": thread_data.get("endorsed", False), + "editable_fields": [], # Deleted content is not editable + "can_delete": False, # Already deleted + "is_deleted": True, + "deleted_at": thread_data.get("deleted_at"), + "deleted_by": deleted_by_username, + "deleted_by_label": get_user_label_fn(deleted_by_id) if deleted_by_id else None, + "close_reason_code": thread_data.get("close_reason_code"), + "close_reason": None, + "closed_by": thread_data.get("closed_by"), + "closed_by_label": None, + } + + +def _process_deleted_comment(comment_data, get_user_label_fn, usernames_set): + """ + Process a single deleted comment into the standardized content item format. + + Args: + comment_data: Raw comment data from forum API + get_user_label_fn: Function to get user labels by user ID + usernames_set: Set to collect usernames for profile image fetch (modified in-place) + + Returns: + dict: Formatted content item for the comment + """ + author_username = comment_data.get("author_username", "") or None + author_id = comment_data.get("author_id", "") + + # If author_username is missing or empty, try to get it from author_id + if not author_username and author_id: + try: + author_user = User.objects.get(id=int(author_id)) + author_username = author_user.username + except (User.DoesNotExist, ValueError): + # If user not found or invalid ID, use placeholder + author_username = None + + deleted_by_id = comment_data.get("deleted_by") + deleted_by_username = None + + # Get deleted_by username + if deleted_by_id: + try: + deleted_user = User.objects.get(id=int(deleted_by_id)) + deleted_by_username = deleted_user.username + usernames_set.add(deleted_by_username) + except (User.DoesNotExist, ValueError): + # If user not found or invalid ID, skip setting deleted fields + pass + + if author_username: + usernames_set.add(author_username) + + # Determine if this is a response (depth=0) or comment (depth>0) + depth = comment_data.get("depth", 0) + comment_type = "response" if depth == 0 else "comment" + + # Get parent thread title for context + thread_id = comment_data.get("comment_thread_id", "") + thread_title = "" + if thread_id: + try: + parent_thread = Thread(id=thread_id).retrieve() + thread_title = parent_thread.get("title", "") + except Exception: # pylint: disable=broad-exception-caught + pass + + # Strip HTML tags from preview + body_text = comment_data.get("body", "") + preview_text = strip_tags(body_text)[:100] if body_text else "" + + comment_id = comment_data.get("_id", comment_data.get("id")) + + # Calculate vote information + votes = comment_data.get("votes", {}) + vote_count = votes.get("up_count", 0) if isinstance(votes, dict) else comment_data.get("vote_count", 0) + + # Get abuse flaggers + abuse_flaggers = comment_data.get("abuse_flaggers", []) + abuse_flagged_count = len(abuse_flaggers) if abuse_flaggers else None + + return { + "id": str(comment_id) + "-comment", + "type": comment_type, + "raw_body": body_text, + "rendered_body": body_text, # For deleted content, just use raw body + "preview_body": preview_text, + "title": thread_title, # Use parent thread title for comments/responses + "course_id": comment_data.get("course_id", ""), + "author": author_username, + "author_id": comment_data.get("author_id", ""), + "author_label": get_user_label_fn(comment_data.get("author_id")), + "thread_id": str(thread_id), + "comment_thread_id": str(thread_id), + "thread_title": thread_title, + "parent_id": ( + str(comment_data.get("parent_id", "")) + if comment_data.get("parent_id") + else None + ), + "created_at": comment_data.get("created_at"), + "updated_at": comment_data.get("updated_at"), + "depth": depth, + "anonymous": comment_data.get("anonymous", False), + "anonymous_to_peers": comment_data.get("anonymous_to_peers", False), + "endorsed": comment_data.get("endorsed", False), + "endorsed_by": comment_data.get("endorsed_by"), + "endorsed_by_label": None, + "endorsed_at": comment_data.get("endorsed_at"), + "abuse_flagged": len(abuse_flaggers) > 0 if abuse_flaggers else False, + "abuse_flagged_count": abuse_flagged_count, + "voted": False, # Cannot vote on deleted content + "vote_count": vote_count, + "editable_fields": [], # Deleted content is not editable + "can_delete": False, # Already deleted + "child_count": comment_data.get("child_count", 0), + "is_deleted": True, + "deleted_at": comment_data.get("deleted_at"), + "deleted_by": deleted_by_username, + "deleted_by_label": get_user_label_fn(deleted_by_id) if deleted_by_id else None, + } + + +def _add_user_profiles_to_content(deleted_content, usernames_set, request): + """ + Fetch user profile images and add them to each content item. + + Args: + deleted_content: List of content items (modified in-place) + usernames_set: Set of usernames to fetch profile images for + request: Django request object for getting profile images + """ + # Add profile images for all users + username_profile_dict = _get_user_profile_dict( + request, usernames=",".join(usernames_set) + ) + + # Add users dict with profile images to each item + for item in deleted_content: + users_dict = {} + + # Add author profile + author_username = item.get("author") + if author_username and author_username in username_profile_dict: + users_dict[author_username] = _user_profile( + username_profile_dict[author_username] + ) + + # Add deleted_by profile + deleted_by_username = item.get("deleted_by") + if deleted_by_username and deleted_by_username in username_profile_dict: + users_dict[deleted_by_username] = _user_profile( + username_profile_dict[deleted_by_username] + ) + + item["users"] = users_dict + + +def get_deleted_content_for_course( + request, course_id, content_type=None, page=1, per_page=20, author_id=None +): + """ + Retrieve all deleted content (threads, comments) for a course. + + Args: + request: The django request object for getting user profile images + course_id (str): Course identifier + content_type (str, optional): Filter by 'thread' or 'comment'. If None, returns all types. + page (int): Page number for pagination (1-based) + per_page (int): Number of items per page + author_id (str, optional): Filter by author ID + + Returns: + dict: Paginated results with deleted content including author labels and profile images + """ + + import math + + from lms.djangoapps.discussion.rest_api.utils import ( + get_course_staff_users_list, + get_course_ta_users_list, + get_moderator_users_list, + ) + + try: + # Get course and user role information for labels + course_key = CourseKey.from_string(course_id) + course = _get_course(course_key, request.user) + + course_staff_user_ids = get_course_staff_users_list(course.id) + moderator_user_ids = get_moderator_users_list(course.id) + ta_user_ids = get_course_ta_users_list(course.id) + + # Get user label function + get_user_label = _get_user_label_function( + course_staff_user_ids, moderator_user_ids, ta_user_ids + ) + + # Build query parameters for forum API + query_params = { + "course_id": course_id, + "is_deleted": True, # Only get deleted content + "page": page, + "per_page": per_page, + } + + if author_id: + query_params["author_id"] = author_id + + deleted_content = [] + total_count = 0 + usernames_set = set() # Track all usernames for profile image fetch + + # Get deleted threads + if content_type is None or content_type == "thread": + try: + deleted_threads = forum_api.get_deleted_threads_for_course( + course_id=course_id, + page=page if content_type == "thread" else 1, + per_page=per_page if content_type == "thread" else 1000, + author_id=author_id, + ) + for thread_data in deleted_threads.get("threads", []): + content_item = _process_deleted_thread( + thread_data, get_user_label, usernames_set + ) + deleted_content.append(content_item) + + if content_type == "thread": + total_count = deleted_threads.get( + "total_count", len(deleted_content) + ) + except Exception as e: # pylint: disable=broad-exception-caught + log.warning( + "Failed to get deleted threads for course %s: %s", course_id, e + ) + + # Get deleted comments + if content_type is None or content_type == "comment": + try: + deleted_comments = forum_api.get_deleted_comments_for_course( + course_id=course_id, + page=page if content_type == "comment" else 1, + per_page=per_page if content_type == "comment" else 1000, + author_id=author_id, + ) + for comment_data in deleted_comments.get("comments", []): + content_item = _process_deleted_comment( + comment_data, get_user_label, usernames_set + ) + deleted_content.append(content_item) + + if content_type == "comment": + total_count = deleted_comments.get( + "total_count", len(deleted_content) + ) + except Exception as e: # pylint: disable=broad-exception-caught + log.warning( + "Failed to get deleted comments for course %s: %s", course_id, e + ) + + # If getting all content types, handle pagination differently + if content_type is None: + total_count = len(deleted_content) + # Sort by deletion date (most recent first) + deleted_content.sort(key=lambda x: x.get("deleted_at", ""), reverse=True) + + # Apply pagination to combined results + start_idx = (page - 1) * per_page + end_idx = start_idx + per_page + deleted_content = deleted_content[start_idx:end_idx] + + # Add profile images for all users + _add_user_profiles_to_content(deleted_content, usernames_set, request) + + # Calculate pagination info + num_pages = math.ceil(total_count / per_page) if total_count > 0 else 1 + + return { + "results": deleted_content, + "pagination": { + "next": None, # Can be computed if needed + "previous": None, # Can be computed if needed + "count": total_count, + "num_pages": num_pages, + }, + } + + except Exception as e: + log.exception("Error getting deleted content for course %s: %s", course_id, e) + raise diff --git a/lms/djangoapps/discussion/rest_api/emails.py b/lms/djangoapps/discussion/rest_api/emails.py new file mode 100644 index 000000000000..e4ebcc21a567 --- /dev/null +++ b/lms/djangoapps/discussion/rest_api/emails.py @@ -0,0 +1,170 @@ +""" +Email notifications for discussion moderation actions. +""" +import logging + +from django.conf import settings +from django.contrib.auth import get_user_model + +log = logging.getLogger(__name__) +User = get_user_model() + +# Try to import ACE at module level for easier testing +try: + from edx_ace import ace + from edx_ace.recipient import Recipient + from edx_ace.message import Message + ACE_AVAILABLE = True +except ImportError: + ace = None + Recipient = None + Message = None + ACE_AVAILABLE = False + + +def send_ban_escalation_email( + banned_user_id, + moderator_id, + course_id, + scope, + reason, + threads_deleted, + comments_deleted +): + """ + Send email to partner-support when user is banned. + + Uses ACE (Automated Communications Engine) for templated emails if available, + otherwise falls back to Django's email system. + + Args: + banned_user_id: ID of the banned user + moderator_id: ID of the moderator who applied the ban + course_id: Course ID where ban was applied + scope: 'course' or 'organization' + reason: Reason for the ban + threads_deleted: Number of threads deleted + comments_deleted: Number of comments deleted + """ + # Check if email notifications are enabled + if not getattr(settings, 'DISCUSSION_MODERATION_BAN_EMAIL_ENABLED', True): + log.info( + "Ban email notifications disabled by settings. " + "User %s banned in course %s (scope: %s)", + banned_user_id, course_id, scope + ) + return + + try: + banned_user = User.objects.get(id=banned_user_id) + moderator = User.objects.get(id=moderator_id) + + # Get escalation email from settings + escalation_email = getattr( + settings, + 'DISCUSSION_MODERATION_ESCALATION_EMAIL', + 'partner-support@edx.org' + ) + + # Try using ACE first (preferred method for edX) + if ACE_AVAILABLE and ace is not None: + message = Message( + app_label='discussion', + name='ban_escalation', + recipient=Recipient(lms_user_id=None, email_address=escalation_email), + context={ + 'banned_username': banned_user.username, + 'banned_email': banned_user.email, + 'banned_user_id': banned_user_id, + 'moderator_username': moderator.username, + 'moderator_email': moderator.email, + 'moderator_id': moderator_id, + 'course_id': str(course_id), + 'scope': scope, + 'reason': reason or 'No reason provided', + 'threads_deleted': threads_deleted, + 'comments_deleted': comments_deleted, + 'total_deleted': threads_deleted + comments_deleted, + } + ) + + ace.send(message) + log.info( + "Ban escalation email sent via ACE to %s for user %s in course %s", + escalation_email, banned_user.username, course_id + ) + + else: + # Fallback to Django's email system if ACE is not available + from django.core.mail import send_mail + from django.template.loader import render_to_string + from django.template import TemplateDoesNotExist + + context = { + 'banned_username': banned_user.username, + 'banned_email': banned_user.email, + 'banned_user_id': banned_user_id, + 'moderator_username': moderator.username, + 'moderator_email': moderator.email, + 'moderator_id': moderator_id, + 'course_id': str(course_id), + 'scope': scope, + 'reason': reason or 'No reason provided', + 'threads_deleted': threads_deleted, + 'comments_deleted': comments_deleted, + 'total_deleted': threads_deleted + comments_deleted, + } + + # Try to render template, fall back to plain text if template doesn't exist + try: + email_body = render_to_string( + 'discussion/ban_escalation_email.txt', + context + ) + except TemplateDoesNotExist: + # Plain text fallback + banned_user_info = "{} ({})".format(banned_user.username, banned_user.email) + moderator_info = "{} ({})".format(moderator.username, moderator.email) + email_body = """ +A user has been banned from discussions: + +Banned User: {} +Moderator: {} +Course: {} +Scope: {} +Reason: {} +Content Deleted: {} threads, {} comments + +Please review this moderation action and follow up as needed. +""".format( + banned_user_info, + moderator_info, + course_id, + scope, + reason or 'No reason provided', + threads_deleted, + comments_deleted + ) + + subject = f'Discussion Ban Alert: {banned_user.username} in {course_id}' + from_email = getattr(settings, 'DEFAULT_FROM_EMAIL', 'no-reply@example.com') + + send_mail( + subject=subject, + message=email_body, + from_email=from_email, + recipient_list=[escalation_email], + fail_silently=False, + ) + + log.info( + "Ban escalation email sent via Django mail to %s for user %s in course %s", + escalation_email, banned_user.username, course_id + ) + + except User.DoesNotExist as e: + log.error("Failed to send ban escalation email: User not found - %s", str(e)) + raise + except Exception as exc: + log.error("Failed to send ban escalation email: %s", str(exc), exc_info=True) + raise diff --git a/lms/djangoapps/discussion/rest_api/forms.py b/lms/djangoapps/discussion/rest_api/forms.py index 8cc7127645b2..f37543723792 100644 --- a/lms/djangoapps/discussion/rest_api/forms.py +++ b/lms/djangoapps/discussion/rest_api/forms.py @@ -1,6 +1,7 @@ """ Discussion API forms """ + import urllib.parse from django.core.exceptions import ValidationError @@ -22,13 +23,15 @@ class UserOrdering(TextChoices): - BY_ACTIVITY = 'activity' - BY_FLAGS = 'flagged' - BY_RECENT_ACTIVITY = 'recency' + BY_ACTIVITY = "activity" + BY_FLAGS = "flagged" + BY_RECENT_ACTIVITY = "recency" + BY_DELETED = "deleted" class _PaginationForm(Form): """A form that includes pagination fields""" + page = IntegerField(required=False, min_value=1) page_size = IntegerField(required=False, min_value=1) @@ -45,6 +48,7 @@ class ThreadListGetForm(_PaginationForm): """ A form to validate query parameters in the thread list retrieval endpoint """ + EXCLUSIVE_PARAMS = ["topic_id", "text_search", "following"] course_id = CharField() @@ -58,17 +62,22 @@ class ThreadListGetForm(_PaginationForm): ) count_flagged = ExtendedNullBooleanField(required=False) flagged = ExtendedNullBooleanField(required=False) + show_deleted = ExtendedNullBooleanField(required=False) view = ChoiceField( - choices=[(choice, choice) for choice in ["unread", "unanswered", "unresponded"]], + choices=[ + (choice, choice) for choice in ["unread", "unanswered", "unresponded"] + ], required=False, ) order_by = ChoiceField( - choices=[(choice, choice) for choice in ["last_activity_at", "comment_count", "vote_count"]], - required=False + choices=[ + (choice, choice) + for choice in ["last_activity_at", "comment_count", "vote_count"] + ], + required=False, ) order_direction = ChoiceField( - choices=[(choice, choice) for choice in ["desc"]], - required=False + choices=[(choice, choice) for choice in ["desc"]], required=False ) requested_fields = MultiValueField(required=False) @@ -85,14 +94,16 @@ def clean_course_id(self): value = self.cleaned_data["course_id"] try: return CourseLocator.from_string(value) - except InvalidKeyError: - raise ValidationError(f"'{value}' is not a valid course id") # lint-amnesty, pylint: disable=raise-missing-from + except InvalidKeyError as e: + raise ValidationError(f"'{value}' is not a valid course id") from e def clean_following(self): """Validate following""" value = self.cleaned_data["following"] if value is False: # lint-amnesty, pylint: disable=no-else-raise - raise ValidationError("The value of the 'following' parameter must be true.") + raise ValidationError( + "The value of the 'following' parameter must be true." + ) else: return value @@ -115,6 +126,7 @@ class ThreadActionsForm(Form): A form to handle fields in thread creation/update that require separate interactions with the comments service. """ + following = BooleanField(required=False) voted = BooleanField(required=False) abuse_flagged = BooleanField(required=False) @@ -126,17 +138,20 @@ class CommentListGetForm(_PaginationForm): """ A form to validate query parameters in the comment list retrieval endpoint """ + thread_id = CharField() flagged = BooleanField(required=False) endorsed = ExtendedNullBooleanField(required=False) requested_fields = MultiValueField(required=False) merge_question_type_responses = BooleanField(required=False) + show_deleted = ExtendedNullBooleanField(required=False) class UserCommentListGetForm(_PaginationForm): """ A form to validate query parameters in the comment list retrieval endpoint """ + course_id = CharField() flagged = BooleanField(required=False) requested_fields = MultiValueField(required=False) @@ -146,8 +161,8 @@ def clean_course_id(self): value = self.cleaned_data["course_id"] try: return CourseLocator.from_string(value) - except InvalidKeyError: - raise ValidationError(f"'{value}' is not a valid course id") # lint-amnesty, pylint: disable=raise-missing-from + except InvalidKeyError as e: + raise ValidationError(f"'{value}' is not a valid course id") from e class CommentActionsForm(Form): @@ -155,6 +170,7 @@ class CommentActionsForm(Form): A form to handle fields in comment creation/update that require separate interactions with the comments service. """ + voted = BooleanField(required=False) abuse_flagged = BooleanField(required=False) @@ -163,6 +179,7 @@ class CommentGetForm(_PaginationForm): """ A form to validate query parameters in the comment retrieval endpoint """ + requested_fields = MultiValueField(required=False) @@ -170,28 +187,34 @@ class CourseDiscussionSettingsForm(Form): """ A form to validate the fields in the course discussion settings requests. """ + course_id = CharField() def __init__(self, *args, **kwargs): - self.request_user = kwargs.pop('request_user') + self.request_user = kwargs.pop("request_user") super().__init__(*args, **kwargs) def clean_course_id(self): """Validate the 'course_id' value""" - course_id = self.cleaned_data['course_id'] + course_id = self.cleaned_data["course_id"] try: course_key = CourseKey.from_string(course_id) - self.cleaned_data['course'] = get_course_with_access(self.request_user, 'load', course_key) - self.cleaned_data['course_key'] = course_key + self.cleaned_data["course"] = get_course_with_access( + self.request_user, "load", course_key + ) + self.cleaned_data["course_key"] = course_key return course_id - except InvalidKeyError: - raise ValidationError(f"'{str(course_id)}' is not a valid course key") # lint-amnesty, pylint: disable=raise-missing-from + except InvalidKeyError as e: + raise ValidationError( + f"'{str(course_id)}' is not a valid course key" + ) from e class CourseDiscussionRolesForm(CourseDiscussionSettingsForm): """ A form to validate the fields in the course discussion roles requests. """ + ROLE_CHOICES = ( (FORUM_ROLE_MODERATOR, FORUM_ROLE_MODERATOR), (FORUM_ROLE_COMMUNITY_TA, FORUM_ROLE_MODERATOR), @@ -199,20 +222,20 @@ class CourseDiscussionRolesForm(CourseDiscussionSettingsForm): ) rolename = ChoiceField( choices=ROLE_CHOICES, - error_messages={"invalid_choice": "Role '%(value)s' does not exist"} + error_messages={"invalid_choice": "Role '%(value)s' does not exist"}, ) def clean_rolename(self): """Validate the 'rolename' value.""" - rolename = urllib.parse.unquote(self.cleaned_data.get('rolename')) - course_id = self.cleaned_data.get('course_key') + rolename = urllib.parse.unquote(self.cleaned_data.get("rolename")) + course_id = self.cleaned_data.get("course_key") if course_id and rolename: try: role = Role.objects.get(name=rolename, course_id=course_id) except Role.DoesNotExist as err: raise ValidationError(f"Role '{rolename}' does not exist") from err - self.cleaned_data['role'] = role + self.cleaned_data["role"] = role return rolename @@ -220,15 +243,17 @@ class TopicListGetForm(Form): """ Form for the topics API get query parameters. """ + topic_id = CharField(required=False) order_by = ChoiceField(choices=TopicOrdering.choices, required=False) def clean_topic_id(self): topic_ids = self.cleaned_data.get("topic_id", None) - return set(topic_ids.strip(',').split(',')) if topic_ids else None + return set(topic_ids.strip(",").split(",")) if topic_ids else None class CourseActivityStatsForm(_PaginationForm): """Form for validating course activity stats API query parameters""" + order_by = ChoiceField(choices=UserOrdering.choices, required=False) username = CharField(required=False) diff --git a/lms/djangoapps/discussion/rest_api/permissions.py b/lms/djangoapps/discussion/rest_api/permissions.py index cfcea5b32834..a3aa6398915b 100644 --- a/lms/djangoapps/discussion/rest_api/permissions.py +++ b/lms/djangoapps/discussion/rest_api/permissions.py @@ -3,10 +3,11 @@ """ from typing import Dict, Set, Union +from opaque_keys import InvalidKeyError from opaque_keys.edx.keys import CourseKey from rest_framework import permissions -from common.djangoapps.student.models import CourseAccessRole, CourseEnrollment +from common.djangoapps.student.models import CourseEnrollment from common.djangoapps.student.roles import ( CourseInstructorRole, CourseStaffRole, @@ -189,42 +190,137 @@ def has_permission(self, request, view): def can_take_action_on_spam(user, course_id): """ - Returns if the user has access to take action against forum spam posts + Returns if the user has access to take action against forum spam posts. + + Grants access to: + - Global Staff (user.is_staff or GlobalStaff role) + - Course Staff for the specific course + - Course Instructors for the specific course + - Forum Moderators for the specific course + - Forum Administrators for the specific course + Parameters: user: User object course_id: CourseKey or string of course_id + + Returns: + bool: True if user can take action on spam, False otherwise """ - if GlobalStaff().has_user(user): + # Global staff have universal access + if GlobalStaff().has_user(user) or user.is_staff: return True if isinstance(course_id, str): course_id = CourseKey.from_string(course_id) - org_id = course_id.org - course_ids = CourseEnrollment.objects.filter(user=user).values_list('course_id', flat=True) - course_ids = [c_id for c_id in course_ids if c_id.org == org_id] + + # Check if user is Course Staff or Instructor for this specific course + if CourseStaffRole(course_id).has_user(user): + return True + + if CourseInstructorRole(course_id).has_user(user): + return True + + # Check forum moderator/administrator roles for this specific course user_roles = set( Role.objects.filter( users=user, - course_id__in=course_ids, - ).values_list('name', flat=True).distinct() + course_id=course_id, + ).values_list('name', flat=True) ) - if bool(user_roles & {FORUM_ROLE_ADMINISTRATOR, FORUM_ROLE_MODERATOR}): - return True - if CourseAccessRole.objects.filter(user=user, course_id__in=course_ids, role__in=["instructor", "staff"]).exists(): + if user_roles & {FORUM_ROLE_ADMINISTRATOR, FORUM_ROLE_MODERATOR}: return True + return False class IsAllowedToBulkDelete(permissions.BasePermission): """ - Permission that checks if the user is staff or an admin. + Permission that checks if the user is allowed to perform bulk delete and ban operations. + + Grants access to: + - Global Staff (superusers) + - Course Staff + - Course Instructors + - Forum Moderators + - Forum Administrators + + Denies access to: + - Unauthenticated users + - Regular students + - Community TAs (they can moderate individual posts but not bulk delete) """ def has_permission(self, request, view): - """Returns true if the user can bulk delete posts""" + """ + Returns True if the user can bulk delete posts and ban users. + + For ViewSet actions, course_id may come from: + 1. URL kwargs (view.kwargs.get('course_id')) - for URL path parameters + 2. Request body (request.data.get('course_id')) - for POST request bodies + """ if not request.user.is_authenticated: return False - course_id = view.kwargs.get("course_id") + # Try to get course_id from URL kwargs or request data + course_id = ( + view.kwargs.get("course_id") or + (request.data.get("course_id") if hasattr(request, 'data') else None) + ) + + # If no course_id provided, we can't check permissions yet + # Let the view handle validation of required course_id + if not course_id: + # For safety, only allow global staff to proceed without course_id + return GlobalStaff().has_user(request.user) or request.user.is_staff + return can_take_action_on_spam(request.user, course_id) + + +class IsAllowedToRestore(permissions.BasePermission): + """ + Permission that checks if the user has privileges to restore individual deleted content. + + This permission is intentionally more permissive than IsAllowedToBulkDelete because: + - Restoring individual content is a less risky operation than bulk deletion + - Users who can see deleted content should be able to restore it + - Course-level moderation staff need this capability for day-to-day moderation + + Allowed users (course-level permissions): + - Global staff (platform-wide) + - Course instructors + - Course staff + - Discussion moderators (course-specific) + - Discussion community TAs (course-specific) + - Discussion administrators (course-specific) + """ + + def has_permission(self, request, view): + """Returns true if the user can restore deleted posts""" + if not request.user.is_authenticated: + return False + + # For restore operations, course_id is in request.data, not URL kwargs + course_id = request.data.get("course_id") + if not course_id: + return False + + # Global staff always has permission + if GlobalStaff().has_user(request.user): + return True + + try: + course_key = CourseKey.from_string(course_id) + except InvalidKeyError: + return False + + # Check if user is course staff or instructor + if CourseStaffRole(course_key).has_user(request.user) or \ + CourseInstructorRole(course_key).has_user(request.user): + return True + + # Check if user has discussion privileges (moderator, community TA, administrator) + if has_discussion_privileges(request.user, course_key): + return True + + return False diff --git a/lms/djangoapps/discussion/rest_api/serializers.py b/lms/djangoapps/discussion/rest_api/serializers.py index 9c2668d0b226..1f7f2264cd19 100644 --- a/lms/djangoapps/discussion/rest_api/serializers.py +++ b/lms/djangoapps/discussion/rest_api/serializers.py @@ -1,13 +1,13 @@ """ Discussion API serializers """ + import html import re - -from bs4 import BeautifulSoup from typing import Dict from urllib.parse import urlencode, urlunparse +from bs4 import BeautifulSoup from django.conf import settings from django.contrib.auth import get_user_model from django.core.exceptions import ObjectDoesNotExist, ValidationError @@ -18,8 +18,12 @@ from common.djangoapps.student.models import get_user_by_username_or_email from common.djangoapps.student.roles import GlobalStaff -from lms.djangoapps.discussion.django_comment_client.base.views import track_thread_lock_unlock_event, \ - track_thread_edited_event, track_comment_edited_event, track_forum_response_mark_event +from lms.djangoapps.discussion.django_comment_client.base.views import ( + track_comment_edited_event, + track_forum_response_mark_event, + track_thread_edited_event, + track_thread_lock_unlock_event, +) from lms.djangoapps.discussion.django_comment_client.utils import ( course_discussion_division_enabled, get_group_id_for_user, @@ -35,16 +39,23 @@ from lms.djangoapps.discussion.rest_api.render import render_body from lms.djangoapps.discussion.rest_api.utils import ( get_course_staff_users_list, - get_moderator_users_list, get_course_ta_users_list, + get_moderator_users_list, + get_user_learner_status, ) from openedx.core.djangoapps.discussions.models import DiscussionTopicLink from openedx.core.djangoapps.discussions.utils import get_group_names_by_id from openedx.core.djangoapps.django_comment_common.comment_client.comment import Comment from openedx.core.djangoapps.django_comment_common.comment_client.thread import Thread -from openedx.core.djangoapps.django_comment_common.comment_client.user import User as CommentClientUser -from openedx.core.djangoapps.django_comment_common.comment_client.utils import CommentClientRequestError -from openedx.core.djangoapps.django_comment_common.models import CourseDiscussionSettings +from openedx.core.djangoapps.django_comment_common.comment_client.user import ( + User as CommentClientUser, +) +from openedx.core.djangoapps.django_comment_common.comment_client.utils import ( + CommentClientRequestError, +) +from openedx.core.djangoapps.django_comment_common.models import ( + CourseDiscussionSettings, +) from openedx.core.djangoapps.user_api.accounts.api import get_profile_images from openedx.core.lib.api.serializers import CourseKeyField @@ -58,6 +69,7 @@ class TopicOrdering(TextChoices): """ Enum for the available options for ordering topics. """ + COURSE_STRUCTURE = "course_structure", "Course Structure" ACTIVITY = "activity", "Activity" NAME = "name", "Name" @@ -72,16 +84,22 @@ def get_context(course, request, thread=None): moderator_user_ids = get_moderator_users_list(course.id) ta_user_ids = get_course_ta_users_list(course.id) requester = request.user - cc_requester = CommentClientUser.from_django_user(requester).retrieve(course_id=course.id) + cc_requester = CommentClientUser.from_django_user(requester).retrieve( + course_id=course.id + ) cc_requester["course_id"] = course.id course_discussion_settings = CourseDiscussionSettings.get(course.id) is_global_staff = GlobalStaff().has_user(requester) - has_moderation_privilege = requester.id in moderator_user_ids or requester.id in ta_user_ids or is_global_staff + all_privileged_ids = set(moderator_user_ids) | set(ta_user_ids) | set(course_staff_user_ids) + has_moderation_privilege = requester.id in all_privileged_ids or is_global_staff return { "course": course, + "course_id": course.id, "request": request, "thread": thread, - "discussion_division_enabled": course_discussion_division_enabled(course_discussion_settings), + "discussion_division_enabled": course_discussion_division_enabled( + course_discussion_settings + ), "group_ids_to_names": get_group_names_by_id(course_discussion_settings), "moderator_user_ids": moderator_user_ids, "course_staff_user_ids": course_staff_user_ids, @@ -136,8 +154,8 @@ def _validate_privileged_access(context: Dict) -> bool: Returns: bool: Course exists and the user has privileged access. """ - course = context.get('course', None) - is_requester_privileged = context.get('has_moderation_privilege') + course = context.get("course", None) + is_requester_privileged = context.get("has_moderation_privilege") return course and is_requester_privileged @@ -157,7 +175,7 @@ def filter_spam_urls_from_html(html_string): patterns.append(re.compile(rf"(https?://)?{domain_pattern}", re.IGNORECASE)) for a_tag in soup.find_all("a", href=True): - href = a_tag.get('href') + href = a_tag.get("href") if href: if any(p.search(href) for p in patterns): a_tag.replace_with(a_tag.get_text(strip=True)) @@ -166,7 +184,7 @@ def filter_spam_urls_from_html(html_string): for text_node in soup.find_all(string=True): new_text = text_node for p in patterns: - new_text = p.sub('', new_text) + new_text = p.sub("", new_text) if new_text != text_node: text_node.replace_with(new_text.strip()) is_spam = True @@ -182,6 +200,7 @@ class _ContentSerializer(serializers.Serializer): id = serializers.CharField(read_only=True) # pylint: disable=invalid-name author = serializers.SerializerMethodField() author_label = serializers.SerializerMethodField() + learner_status = serializers.SerializerMethodField() created_at = serializers.CharField(read_only=True) updated_at = serializers.CharField(read_only=True) raw_body = serializers.CharField(source="body", validators=[validate_not_blank]) @@ -194,8 +213,16 @@ class _ContentSerializer(serializers.Serializer): anonymous = serializers.BooleanField(default=False) anonymous_to_peers = serializers.BooleanField(default=False) last_edit = serializers.SerializerMethodField(required=False) - edit_reason_code = serializers.CharField(required=False, validators=[validate_edit_reason_code]) + edit_reason_code = serializers.CharField( + required=False, validators=[validate_edit_reason_code] + ) edit_by_label = serializers.SerializerMethodField(required=False) + is_deleted = serializers.SerializerMethodField(read_only=True) + deleted_at = serializers.SerializerMethodField(read_only=True) + deleted_by = serializers.SerializerMethodField(read_only=True) + deleted_by_label = serializers.SerializerMethodField(read_only=True) + is_author_banned = serializers.SerializerMethodField(read_only=True) + author_ban_scope = serializers.SerializerMethodField(read_only=True) non_updatable_fields = set() @@ -217,7 +244,10 @@ def _is_user_privileged(self, user_id): Returns a boolean indicating whether the given user_id identifies a privileged user. """ - return user_id in self.context["moderator_user_ids"] or user_id in self.context["ta_user_ids"] + return ( + user_id in self.context["moderator_user_ids"] + or user_id in self.context["ta_user_ids"] + ) def _is_anonymous(self, obj): """ @@ -225,13 +255,13 @@ def _is_anonymous(self, obj): the requester. """ user_id = self.context["request"].user.id - is_user_staff = user_id in self.context["moderator_user_ids"] or user_id in self.context["ta_user_ids"] - - return ( - obj["anonymous"] or - obj["anonymous_to_peers"] and not is_user_staff + is_user_staff = ( + user_id in self.context["moderator_user_ids"] + or user_id in self.context["ta_user_ids"] ) + return obj["anonymous"] or obj["anonymous_to_peers"] and not is_user_staff + def get_author(self, obj): """ Returns the author's username, or None if the content is anonymous. @@ -248,10 +278,9 @@ def _get_user_label(self, user_id): is_ta = user_id in self.context["ta_user_ids"] return ( - "Staff" if is_staff else - "Moderator" if is_moderator else - "Community TA" if is_ta else - None + "Staff" + if is_staff + else "Moderator" if is_moderator else "Community TA" if is_ta else None ) def _get_user_label_from_username(self, username): @@ -275,13 +304,35 @@ def get_author_label(self, obj): user_id = int(obj["user_id"]) return self._get_user_label(user_id) + def get_learner_status(self, obj): + """ + Get the learner status for the discussion post author. + Returns one of: "anonymous", "staff", "new", "regular" + """ + # Skip for anonymous content + if self._is_anonymous(obj) or obj.get("user_id") is None: + return "anonymous" + + try: + user = User.objects.get(id=int(obj["user_id"])) + except (User.DoesNotExist, ValueError): + return "anonymous" + + course = self.context.get("course") + if not course: + return "anonymous" + + return get_user_learner_status(user, course.id) + def get_rendered_body(self, obj): """ Returns the rendered body content. """ if self._rendered_body is None: self._rendered_body = render_body(obj["body"]) - self._rendered_body, is_spam = filter_spam_urls_from_html(self._rendered_body) + self._rendered_body, is_spam = filter_spam_urls_from_html( + self._rendered_body + ) if is_spam and settings.CONTENT_FOR_SPAM_POSTS: self._rendered_body = settings.CONTENT_FOR_SPAM_POSTS return self._rendered_body @@ -293,8 +344,9 @@ def get_abuse_flagged(self, obj): """ total_abuse_flaggers = len(obj.get("abuse_flaggers", [])) return ( - self.context["has_moderation_privilege"] and total_abuse_flaggers > 0 or - self.context["cc_requester"]["id"] in obj.get("abuse_flaggers", []) + self.context["has_moderation_privilege"] + and total_abuse_flaggers > 0 + or self.context["cc_requester"]["id"] in obj.get("abuse_flaggers", []) ) def get_voted(self, obj): @@ -327,7 +379,7 @@ def get_last_edit(self, obj): Returns information about the last edit for this content for privileged users. """ - is_user_author = str(obj['user_id']) == str(self.context['request'].user.id) + is_user_author = str(obj["user_id"]) == str(self.context["request"].user.id) if not (_validate_privileged_access(self.context) or is_user_author): return None edit_history = obj.get("edit_history") @@ -343,12 +395,233 @@ def get_edit_by_label(self, obj): """ Returns the role label for the last edit user. """ - is_user_author = str(obj['user_id']) == str(self.context['request'].user.id) + is_user_author = str(obj["user_id"]) == str(self.context["request"].user.id) is_user_privileged = _validate_privileged_access(self.context) edit_history = obj.get("edit_history") if (is_user_author or is_user_privileged) and edit_history: last_edit = edit_history[-1] - return self._get_user_label_from_username(last_edit.get('editor_username')) + return self._get_user_label_from_username(last_edit.get("editor_username")) + + def get_is_deleted(self, obj): + """ + Returns the is_deleted status for privileged users or content authors. + """ + is_user_author = str(obj["user_id"]) == str(self.context["request"].user.id) + if not (_validate_privileged_access(self.context) or is_user_author): + return None + return obj.get("is_deleted", False) + + def get_deleted_at(self, obj): + """ + Returns the deletion timestamp for privileged users or content authors. + """ + is_user_author = str(obj["user_id"]) == str(self.context["request"].user.id) + if not (_validate_privileged_access(self.context) or is_user_author): + return None + return obj.get("deleted_at") + + def get_deleted_by(self, obj): + """ + Returns the username of the user who deleted this content for privileged users or content authors. + """ + is_user_author = str(obj["user_id"]) == str(self.context["request"].user.id) + if not (_validate_privileged_access(self.context) or is_user_author): + return None + deleted_by_id = obj.get("deleted_by") + if deleted_by_id: + try: + user = User.objects.get(id=int(deleted_by_id)) + return user.username + except (User.DoesNotExist, ValueError): + return None + return None + + def get_deleted_by_label(self, obj): + """ + Returns the role label for the user who deleted this content for privileged users only. + """ + if not _validate_privileged_access(self.context): + return None + deleted_by_id = obj.get("deleted_by") + if deleted_by_id: + try: + return self._get_user_label(int(deleted_by_id)) + except (ValueError, TypeError): + return None + return None + + def _get_author_ban_cache_key(self, course_id, user_id): + """Build a stable cache key for author ban lookups.""" + return (str(course_id), int(user_id)) + + def _get_author_from_cache(self, user_id): + """Fetch author from per-request cache or database.""" + user_cache = self.context.setdefault("_author_ban_user_cache", {}) + if user_id not in user_cache: + try: + user_cache[user_id] = User.objects.get(id=user_id) + except User.DoesNotExist: + user_cache[user_id] = None + return user_cache[user_id] + + def get_is_author_banned(self, obj): + """ + Returns True if the content author is banned from discussions. + Returns False for anonymous content or if ban check fails. + """ + from forum import api as forum_api + from lms.djangoapps.discussion.toggles import ENABLE_DISCUSSION_BAN + + # Skip for anonymous content + if self._is_anonymous(obj) or obj.get("user_id") is None: + return False + + # Skip if ban function not available + is_user_banned_func = getattr(forum_api, 'is_user_banned', None) + if not is_user_banned_func: + return False + + # Skip if feature flag is not enabled + course_id = self.context.get("course_id") + if not course_id or not ENABLE_DISCUSSION_BAN.is_enabled(course_id): + return False + + try: + user_id = int(obj["user_id"]) + except (ValueError, TypeError): + return False + + cache_key = self._get_author_ban_cache_key(course_id, user_id) + ban_status_cache = self.context.setdefault("_author_ban_status_cache", {}) + if cache_key in ban_status_cache: + return ban_status_cache[cache_key] + + try: + user = self._get_author_from_cache(user_id) + if not user: + ban_status_cache[cache_key] = False + return False + + is_banned = is_user_banned_func(user, course_id) + ban_status_cache[cache_key] = is_banned + return is_banned + except (User.DoesNotExist, ValueError, Exception): # pylint: disable=broad-except + ban_status_cache[cache_key] = False + + return False + + def get_author_ban_scope(self, obj): + """ + Returns the scope of the author's ban ('course' or 'organization'). + Returns None for anonymous content, unbanned users, or if check fails. + """ + from forum import api as forum_api + from lms.djangoapps.discussion.toggles import ENABLE_DISCUSSION_BAN + import logging + logger = logging.getLogger(__name__) + + # Skip for anonymous content + if self._is_anonymous(obj) or obj.get("user_id") is None: + return None + + # Skip if required functions not available + is_user_banned_func = getattr(forum_api, 'is_user_banned', None) + get_user_bans_func = getattr(forum_api, 'get_user_bans', None) + if not is_user_banned_func: + return None + + # Skip if feature flag is not enabled + course_id = self.context.get("course_id") + if not course_id or not ENABLE_DISCUSSION_BAN.is_enabled(course_id): + return None + + try: + user_id = int(obj["user_id"]) + except (ValueError, TypeError): + return None + + cache_key = self._get_author_ban_cache_key(course_id, user_id) + ban_scope_cache = self.context.setdefault("_author_ban_scope_cache", {}) + if cache_key in ban_scope_cache: + return ban_scope_cache[cache_key] + + ban_status_cache = self.context.setdefault("_author_ban_status_cache", {}) + + try: + user = self._get_author_from_cache(user_id) + if not user: + ban_scope_cache[cache_key] = None + return None + + if not course_id: + ban_scope_cache[cache_key] = None + return None + + # First check if user is banned at all + user_banned = ban_status_cache.get(cache_key) + if user_banned is None: + user_banned = is_user_banned_func(user, course_id) + ban_status_cache[cache_key] = user_banned + + if not user_banned: + ban_scope_cache[cache_key] = None + return None + + # Try to get all active bans for this user and course + if get_user_bans_func: + try: + bans = get_user_bans_func(user=user, course_id=course_id) + # Check for organization-level ban first (higher precedence) + for ban in bans: + if ban.get('is_active') and ban.get('scope') == 'organization': + ban_scope_cache[cache_key] = 'organization' + return 'organization' + # Then check for course-level ban + for ban in bans: + if ban.get('is_active') and ban.get('scope') == 'course': + ban_scope_cache[cache_key] = 'course' + return 'course' + except Exception as e: # pylint: disable=broad-except + logger.debug( + "Unable to fetch ban list for ban-scope detection. course_id=%s user_id=%s error=%s", + course_id, + obj.get("user_id"), + e, + ) + + # Fallback: Try checking each scope individually using is_user_banned + # check_org parameter: True = include org checks, False = course-only + try: + # Check course-only (check_org=False means don't check org) + course_only = is_user_banned_func(user, course_id, check_org=False) + + # If course-only check returns False but user IS banned, must be org-banned + if not course_only: + ban_scope_cache[cache_key] = 'organization' + return 'organization' + + # If course-only check returns True, it's course-level ban + ban_scope_cache[cache_key] = 'course' + return 'course' + except TypeError as e: + # check_org parameter might not exist in older versions + logger.debug( + "check_org parameter unsupported during ban-scope detection. course_id=%s user_id=%s error=%s", + course_id, + obj.get("user_id"), + e, + ) + + except (User.DoesNotExist, ValueError, Exception) as e: # pylint: disable=broad-except + logger.warning( + "Unable to determine author ban scope. course_id=%s user_id=%s error=%s", + self.context.get("course_id"), + obj.get("user_id"), + e, + ) + + ban_scope_cache[cache_key] = None + return None class ThreadSerializer(_ContentSerializer): @@ -359,13 +632,15 @@ class ThreadSerializer(_ContentSerializer): not had retrieve() called, because of the interaction between DRF's attempts at introspection and Thread's __getattr__. """ + course_id = serializers.CharField() - topic_id = serializers.CharField(source="commentable_id", validators=[validate_not_blank]) + topic_id = serializers.CharField( + source="commentable_id", validators=[validate_not_blank] + ) group_id = serializers.IntegerField(required=False, allow_null=True) group_name = serializers.SerializerMethodField() type = serializers.ChoiceField( - source="thread_type", - choices=[(val, val) for val in ["discussion", "question"]] + source="thread_type", choices=[(val, val) for val in ["discussion", "question"]] ) preview_body = serializers.SerializerMethodField() abuse_flagged_count = serializers.SerializerMethodField(required=False) @@ -380,8 +655,12 @@ class ThreadSerializer(_ContentSerializer): non_endorsed_comment_list_url = serializers.SerializerMethodField() read = serializers.BooleanField(required=False) has_endorsed = serializers.BooleanField(source="endorsed", read_only=True) - response_count = serializers.IntegerField(source="resp_total", read_only=True, required=False) - close_reason_code = serializers.CharField(required=False, validators=[validate_close_reason_code]) + response_count = serializers.IntegerField( + source="resp_total", read_only=True, required=False + ) + close_reason_code = serializers.CharField( + required=False, validators=[validate_close_reason_code] + ) close_reason = serializers.SerializerMethodField() closed_by = serializers.SerializerMethodField() closed_by_label = serializers.SerializerMethodField(required=False) @@ -427,9 +706,8 @@ def get_comment_list_url(self, obj, endorsed=None): Returns the URL to retrieve the thread's comments, optionally including the endorsed query parameter. """ - if ( - (obj["thread_type"] == "question" and endorsed is None) or - (obj["thread_type"] == "discussion" and endorsed is not None) + if (obj["thread_type"] == "question" and endorsed is None) or ( + obj["thread_type"] == "discussion" and endorsed is not None ): return None path = reverse("comment-list") @@ -473,13 +751,17 @@ def get_preview_body(self, obj): """ Returns a cleaned version of the thread's body to display in a preview capacity. """ - return strip_tags(self.get_rendered_body(obj)).replace('\n', ' ').replace(' ', ' ') + return ( + strip_tags(self.get_rendered_body(obj)) + .replace("\n", " ") + .replace(" ", " ") + ) def get_close_reason(self, obj): """ Returns the reason for which the thread was closed. """ - is_user_author = str(obj['user_id']) == str(self.context['request'].user.id) + is_user_author = str(obj["user_id"]) == str(self.context["request"].user.id) if not (_validate_privileged_access(self.context) or is_user_author): return None reason_code = obj.get("close_reason_code") @@ -490,7 +772,7 @@ def get_closed_by(self, obj): Returns the username of the moderator who closed this thread, only to other privileged users and author. """ - is_user_author = str(obj['user_id']) == str(self.context['request'].user.id) + is_user_author = str(obj["user_id"]) == str(self.context["request"].user.id) if _validate_privileged_access(self.context) or is_user_author: return obj.get("closed_by") @@ -498,7 +780,7 @@ def get_closed_by_label(self, obj): """ Returns the role label for the user who closed the post. """ - is_user_author = str(obj['user_id']) == str(self.context['request'].user.id) + is_user_author = str(obj["user_id"]) == str(self.context["request"].user.id) if is_user_author or _validate_privileged_access(self.context): return self._get_user_label_from_username(obj.get("closed_by")) @@ -513,18 +795,31 @@ def update(self, instance, validated_data): requesting_user_id = self.context["cc_requester"]["id"] if key == "closed" and val: instance["closing_user_id"] = requesting_user_id - track_thread_lock_unlock_event(self.context['request'], self.context['course'], - instance, validated_data.get('close_reason_code')) + track_thread_lock_unlock_event( + self.context["request"], + self.context["course"], + instance, + validated_data.get("close_reason_code"), + ) if key == "closed" and not val: instance["closing_user_id"] = requesting_user_id - track_thread_lock_unlock_event(self.context['request'], self.context['course'], - instance, validated_data.get('close_reason_code'), locked=False) + track_thread_lock_unlock_event( + self.context["request"], + self.context["course"], + instance, + validated_data.get("close_reason_code"), + locked=False, + ) if key == "body" and val: instance["editing_user_id"] = requesting_user_id - track_thread_edited_event(self.context['request'], self.context['course'], - instance, validated_data.get('edit_reason_code')) + track_thread_edited_event( + self.context["request"], + self.context["course"], + instance, + validated_data.get("edit_reason_code"), + ) instance.save() return instance @@ -537,6 +832,7 @@ class CommentSerializer(_ContentSerializer): not had retrieve() called, because of the interaction between DRF's attempts at introspection and Comment's __getattr__. """ + thread_id = serializers.CharField() parent_id = serializers.CharField(required=False, allow_null=True) endorsed = serializers.BooleanField(required=False) @@ -551,7 +847,7 @@ class CommentSerializer(_ContentSerializer): non_updatable_fields = NON_UPDATABLE_COMMENT_FIELDS def __init__(self, *args, **kwargs): - remove_fields = kwargs.pop('remove_fields', None) + remove_fields = kwargs.pop("remove_fields", None) super().__init__(*args, **kwargs) if remove_fields: @@ -573,8 +869,8 @@ def get_endorsed_by(self, obj): # Avoid revealing the identity of an anonymous non-staff question # author who has endorsed a comment in the thread if not ( - self._is_anonymous(self.context["thread"]) and - not self._is_user_privileged(endorser_id) + self._is_anonymous(self.context["thread"]) + and not self._is_user_privileged(endorser_id) ): return User.objects.get(id=endorser_id).username return None @@ -616,7 +912,7 @@ def to_representation(self, data): # Django Rest Framework v3 no longer includes None values # in the representation. To maintain the previous behavior, # we do this manually instead. - if 'parent_id' not in data: + if "parent_id" not in data: data["parent_id"] = None return data @@ -658,7 +954,7 @@ def create(self, validated_data): comment = Comment( course_id=self.context["thread"]["course_id"], user_id=self.context["cc_requester"]["id"], - **validated_data + **validated_data, ) comment.save() return comment @@ -671,12 +967,18 @@ def update(self, instance, validated_data): # endorsement_user_id on update requesting_user_id = self.context["cc_requester"]["id"] if key == "endorsed": - track_forum_response_mark_event(self.context['request'], self.context['course'], instance, val) + track_forum_response_mark_event( + self.context["request"], self.context["course"], instance, val + ) instance["endorsement_user_id"] = requesting_user_id if key == "body" and val: instance["editing_user_id"] = requesting_user_id - track_comment_edited_event(self.context['request'], self.context['course'], - instance, validated_data.get('edit_reason_code')) + track_comment_edited_event( + self.context["request"], + self.context["course"], + instance, + validated_data.get("edit_reason_code"), + ) instance.save() return instance @@ -686,6 +988,7 @@ class DiscussionTopicSerializer(serializers.Serializer): """ Serializer for DiscussionTopic """ + id = serializers.CharField(read_only=True) # pylint: disable=invalid-name name = serializers.CharField(read_only=True) thread_list_url = serializers.CharField(read_only=True) @@ -715,10 +1018,11 @@ class DiscussionTopicSerializerV2(serializers.Serializer): """ Serializer for new style topics. """ + id = serializers.CharField( # pylint: disable=invalid-name read_only=True, source="external_id", - help_text="Provider-specific unique id for the topic" + help_text="Provider-specific unique id for the topic", ) usage_key = serializers.CharField( read_only=True, @@ -742,10 +1046,13 @@ def get_thread_counts(self, obj: DiscussionTopicLink) -> Dict[str, int]: """ Get thread counts from provided context """ - return self.context['thread_counts'].get(obj.external_id, { - "discussion": 0, - "question": 0, - }) + return self.context["thread_counts"].get( + obj.external_id, + { + "discussion": 0, + "question": 0, + }, + ) class DiscussionRolesSerializer(serializers.Serializer): @@ -753,10 +1060,7 @@ class DiscussionRolesSerializer(serializers.Serializer): Serializer for course discussion roles. """ - ACTION_CHOICES = ( - ('allow', 'allow'), - ('revoke', 'revoke') - ) + ACTION_CHOICES = (("allow", "allow"), ("revoke", "revoke")) action = serializers.ChoiceField(ACTION_CHOICES) user_id = serializers.CharField() @@ -777,14 +1081,16 @@ def validate_user_id(self, user_id): self.user = get_user_by_username_or_email(user_id) return user_id except User.DoesNotExist as err: - raise ValidationError(f"'{user_id}' is not a valid student identifier") from err + raise ValidationError( + f"'{user_id}' is not a valid student identifier" + ) from err def validate(self, attrs): """Validate the data at an object level.""" # Store the user object to avoid fetching it again. - if hasattr(self, 'user'): - attrs['user'] = self.user + if hasattr(self, "user"): + attrs["user"] = self.user return attrs def create(self, validated_data): @@ -802,6 +1108,7 @@ class DiscussionRolesMemberSerializer(serializers.Serializer): """ Serializer for course discussion roles member data. """ + username = serializers.CharField() email = serializers.EmailField() first_name = serializers.CharField() @@ -810,7 +1117,7 @@ class DiscussionRolesMemberSerializer(serializers.Serializer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.course_discussion_settings = self.context['course_discussion_settings'] + self.course_discussion_settings = self.context["course_discussion_settings"] def get_group_name(self, instance): """Return the group name of the user.""" @@ -833,6 +1140,7 @@ class DiscussionRolesListSerializer(serializers.Serializer): """ Serializer for course discussion roles member list. """ + course_id = serializers.CharField() results = serializers.SerializerMethodField() division_scheme = serializers.SerializerMethodField() @@ -840,15 +1148,17 @@ class DiscussionRolesListSerializer(serializers.Serializer): def get_results(self, obj): """Return the nested serializer data representing a list of member users.""" context = { - 'course_id': obj['course_id'], - 'course_discussion_settings': self.context['course_discussion_settings'] + "course_id": obj["course_id"], + "course_discussion_settings": self.context["course_discussion_settings"], } - serializer = DiscussionRolesMemberSerializer(obj['users'], context=context, many=True) + serializer = DiscussionRolesMemberSerializer( + obj["users"], context=context, many=True + ) return serializer.data def get_division_scheme(self, obj): # pylint: disable=unused-argument """Return the division scheme for the course.""" - return self.context['course_discussion_settings'].division_scheme + return self.context["course_discussion_settings"].division_scheme def create(self, validated_data): """ @@ -865,9 +1175,13 @@ class UserStatsSerializer(serializers.Serializer): """ Serializer for course user stats. """ + threads = serializers.IntegerField() replies = serializers.IntegerField() responses = serializers.IntegerField() + deleted_threads = serializers.IntegerField(required=False, default=0) + deleted_replies = serializers.IntegerField(required=False, default=0) + deleted_responses = serializers.IntegerField(required=False, default=0) active_flags = serializers.IntegerField() inactive_flags = serializers.IntegerField() username = serializers.CharField() @@ -885,27 +1199,36 @@ class BlackoutDateSerializer(serializers.Serializer): """ Serializer for blackout dates. """ - start = serializers.DateTimeField(help_text="The ISO 8601 timestamp for the start of the blackout period") - end = serializers.DateTimeField(help_text="The ISO 8601 timestamp for the end of the blackout period") + + start = serializers.DateTimeField( + help_text="The ISO 8601 timestamp for the start of the blackout period" + ) + end = serializers.DateTimeField( + help_text="The ISO 8601 timestamp for the end of the blackout period" + ) class ReasonCodeSeralizer(serializers.Serializer): """ Serializer for reason codes. """ + code = serializers.CharField(help_text="A code for the an edit or close reason") - label = serializers.CharField(help_text="A user-friendly name text for the close or edit reason") + label = serializers.CharField( + help_text="A user-friendly name text for the close or edit reason" + ) class CourseMetadataSerailizer(serializers.Serializer): """ Serializer for course metadata. """ + id = CourseKeyField(help_text="The identifier of the course") blackouts = serializers.ListField( child=BlackoutDateSerializer(), help_text="A list of objects representing blackout periods " - "(during which discussions are read-only except for privileged users)." + "(during which discussions are read-only except for privileged users).", ) thread_list_url = serializers.URLField( help_text="The URL of the list of all threads in the course.", @@ -913,7 +1236,9 @@ class CourseMetadataSerailizer(serializers.Serializer): following_thread_list_url = serializers.URLField( help_text="thread_list_url with parameter following=True", ) - topics_url = serializers.URLField(help_text="The URL of the topic listing for the course.") + topics_url = serializers.URLField( + help_text="The URL of the topic listing for the course." + ) allow_anonymous = serializers.BooleanField( help_text="A boolean indicating whether anonymous posts are allowed or not.", ) @@ -944,3 +1269,141 @@ class CourseMetadataSerailizer(serializers.Serializer): child=ReasonCodeSeralizer(), help_text="A list of reasons that can be specified by moderators for editing a post, response, or comment", ) + + +class BulkDeleteBanRequestSerializer(serializers.Serializer): + """ + Request payload for bulk delete + ban action. + + Accepts either user_id (for programmatic access) or username (for UI/human convenience). + Internally normalizes to user_id before processing. + """ + + user_id = serializers.IntegerField( + required=False, + help_text="User ID to ban. Either user_id or username must be provided." + ) + username = serializers.CharField( + required=False, + max_length=150, + help_text="Username to ban. Converted to user_id internally. Either user_id or username must be provided." + ) + course_id = serializers.CharField(max_length=255, required=True) + ban_user = serializers.BooleanField(default=False) + ban_scope = serializers.ChoiceField( + choices=['course', 'organization'], + default='course', + help_text="Scope of the ban: 'course' for course-level or 'organization' for organization-level" + ) + reason = serializers.CharField( + required=False, + allow_blank=True, + max_length=1000 + ) + + def validate(self, data): + """ + Validate and normalize user identification. + + - Ensures either user_id or username is provided + - Converts username to user_id if needed + - Validates ban requirements (reason, permissions) + """ + # Validate that either user_id or username is provided + if not data.get('user_id') and not data.get('username'): + raise serializers.ValidationError({ + 'user_id': "Either user_id or username must be provided." + }) + + # Normalize username to user_id for internal processing + # This allows the view/task to always work with user_id + if data.get('username') and not data.get('user_id'): + try: + user = User.objects.get(username=data['username']) + data['user_id'] = user.id + # Keep username for logging/audit purposes + data['resolved_username'] = user.username + except User.DoesNotExist as exc: + raise serializers.ValidationError({ + 'username': f"User with username '{data['username']}' does not exist." + }) from exc + elif data.get('user_id'): + # If user_id provided directly, resolve username for consistency + try: + user = User.objects.get(id=data['user_id']) + data['resolved_username'] = user.username + except User.DoesNotExist as exc: + raise serializers.ValidationError({ + 'user_id': f"User with ID {data['user_id']} does not exist." + }) from exc + + if data.get('ban_user'): + reason = data.get('reason', '').strip() + if not reason: + raise serializers.ValidationError({ + 'reason': "Reason is required when banning a user." + }) + + # Validate that organization-level bans require elevated permissions + # only when a ban is requested. + if data.get('ban_user') and data.get('ban_scope') == 'organization': + request = self.context.get('request') + if request and not ( + GlobalStaff().has_user(request.user) or request.user.is_staff + ): + raise serializers.ValidationError({ + 'ban_scope': "Organization-level bans require global staff permissions." + }) + + return data + + +class BanUserRequestSerializer(serializers.Serializer): + """ + Request payload for standalone ban action (without bulk delete). + + For direct ban from UI moderation actions. + """ + + user_id = serializers.IntegerField( + required=False, + help_text="User ID to ban. Either user_id or username must be provided." + ) + username = serializers.CharField( + required=False, + max_length=150, + help_text="Username to ban. Converted to user_id internally. Either user_id or username must be provided." + ) + course_id = serializers.CharField( + max_length=255, + required=True, + help_text="Course ID for course-level bans or org context for organization-level bans" + ) + scope = serializers.ChoiceField( + choices=['course', 'organization'], + default='course', + help_text="Scope of the ban: 'course' for course-level or 'organization' for organization-level" + ) + reason = serializers.CharField( + required=False, + allow_blank=True, + max_length=1000, + help_text="Reason for the ban (optional)" + ) + + def validate(self, data): + """ + Validate and normalize user identification. + """ + # Validate that either user_id or username is provided + if not data.get('user_id') and not data.get('username'): + raise serializers.ValidationError({ + 'user_id': "Either user_id or username must be provided." + }) + + # Normalize username to user_id if provided (view will validate existence) + if data.get('username') and not data.get('user_id'): + # Don't validate user existence here - let the view return 404 + # Just record the username for the view to resolve + data['lookup_username'] = data['username'] + return data diff --git a/lms/djangoapps/discussion/rest_api/tasks.py b/lms/djangoapps/discussion/rest_api/tasks.py index cd725a3513dc..3bc96b654288 100644 --- a/lms/djangoapps/discussion/rest_api/tasks.py +++ b/lms/djangoapps/discussion/rest_api/tasks.py @@ -1,32 +1,36 @@ """ Contain celery tasks """ + import logging from celery import shared_task from django.contrib.auth import get_user_model from edx_django_utils.monitoring import set_code_owner_attribute -from opaque_keys.edx.locator import CourseKey from eventtracking import tracker +from opaque_keys.edx.keys import CourseKey -from common.djangoapps.student.roles import CourseStaffRole, CourseInstructorRole +from common.djangoapps.student.roles import CourseInstructorRole, CourseStaffRole from common.djangoapps.track import segment from lms.djangoapps.courseware.courses import get_course_with_access from lms.djangoapps.discussion.django_comment_client.utils import get_user_role_names -from lms.djangoapps.discussion.rest_api.discussions_notifications import DiscussionNotificationSender +from lms.djangoapps.discussion.rest_api.discussions_notifications import ( + DiscussionNotificationSender, +) from lms.djangoapps.discussion.rest_api.utils import can_user_notify_all_learners from openedx.core.djangoapps.django_comment_common.comment_client import Comment from openedx.core.djangoapps.django_comment_common.comment_client.thread import Thread from openedx.core.djangoapps.notifications.config.waffle import ENABLE_NOTIFICATIONS - User = get_user_model() log = logging.getLogger(__name__) @shared_task @set_code_owner_attribute -def send_thread_created_notification(thread_id, course_key_str, user_id, notify_all_learners=False): +def send_thread_created_notification( + thread_id, course_key_str, user_id, notify_all_learners=False +): """ Send notification when a new thread is created """ @@ -40,17 +44,21 @@ def send_thread_created_notification(thread_id, course_key_str, user_id, notify_ is_course_staff = CourseStaffRole(course_key).has_user(user) is_course_admin = CourseInstructorRole(course_key).has_user(user) user_roles = get_user_role_names(user, course_key) - if not can_user_notify_all_learners(user_roles, is_course_staff, is_course_admin): + if not can_user_notify_all_learners( + user_roles, is_course_staff, is_course_admin + ): return - course = get_course_with_access(user, 'load', course_key, check_if_enrolled=True) + course = get_course_with_access(user, "load", course_key, check_if_enrolled=True) notification_sender = DiscussionNotificationSender(thread, course, user) notification_sender.send_new_thread_created_notification(notify_all_learners) @shared_task @set_code_owner_attribute -def send_response_notifications(thread_id, course_key_str, user_id, comment_id, parent_id=None): +def send_response_notifications( + thread_id, course_key_str, user_id, comment_id, parent_id=None +): """ Send notifications to users who are subscribed to the thread. """ @@ -59,8 +67,10 @@ def send_response_notifications(thread_id, course_key_str, user_id, comment_id, return thread = Thread(id=thread_id).retrieve() user = User.objects.get(id=user_id) - course = get_course_with_access(user, 'load', course_key, check_if_enrolled=True) - notification_sender = DiscussionNotificationSender(thread, course, user, parent_id, comment_id) + course = get_course_with_access(user, "load", course_key, check_if_enrolled=True) + notification_sender = DiscussionNotificationSender( + thread, course, user, parent_id, comment_id + ) notification_sender.send_new_comment_notification() notification_sender.send_new_response_notification() notification_sender.send_new_comment_on_response_notification() @@ -69,7 +79,9 @@ def send_response_notifications(thread_id, course_key_str, user_id, comment_id, @shared_task @set_code_owner_attribute -def send_response_endorsed_notifications(thread_id, response_id, course_key_str, endorsed_by): +def send_response_endorsed_notifications( + thread_id, response_id, course_key_str, endorsed_by +): """ Send notifications when a response is marked answered/ endorsed """ @@ -80,8 +92,10 @@ def send_response_endorsed_notifications(thread_id, response_id, course_key_str, response = Comment(id=response_id).retrieve() creator = User.objects.get(id=response.user_id) endorser = User.objects.get(id=endorsed_by) - course = get_course_with_access(creator, 'load', course_key, check_if_enrolled=True) - notification_sender = DiscussionNotificationSender(thread, course, creator, comment_id=response_id) + course = get_course_with_access(creator, "load", course_key, check_if_enrolled=True) + notification_sender = DiscussionNotificationSender( + thread, course, creator, comment_id=response_id + ) # skip sending notification to author of thread if they are the same as the author of the response if response.user_id != thread.user_id: # sends notification to author of thread @@ -92,22 +106,177 @@ def send_response_endorsed_notifications(thread_id, response_id, course_key_str, notification_sender.send_response_endorsed_notification() +@shared_task( + bind=True, # Enable retry context and access to task instance + max_retries=3, # Retry up to 3 times on failure + default_retry_delay=60, # Wait 60 seconds between retries + autoretry_for=(OSError, TimeoutError), # Only retry on transient network/IO errors + retry_backoff=True, # Exponential backoff between retries + retry_jitter=True, # Add randomization to retry delays +) +@set_code_owner_attribute +def delete_course_post_for_user( # pylint: disable=too-many-statements + self, + user_id, + username=None, + course_ids=None, + event_data=None, + # NEW PARAMETERS (backward compatible - all have defaults): + ban_user=False, + ban_scope='course', + moderator_id=None, + reason=None, +): + """ + Delete all discussion posts for a user and optionally ban them. + + BACKWARD COMPATIBLE: Existing callers without ban_user parameter + will experience no change in behavior. + + Args: + self: Task instance (when bind=True) + user_id: User whose posts to delete + username: Username of the user (optional, will be fetched if not provided) + course_ids: List of course IDs (API sends single course wrapped in array) + event_data: Event tracking metadata + ban_user: If True, create ban record (NEW) + ban_scope: 'course' or 'organization' (NEW) + moderator_id: Moderator applying ban (NEW) + reason: Ban reason (NEW) + """ + event_data = event_data or {} + log.info( + f"<> Deleting all posts for {username} in course {course_ids}" + ) + # Get triggered_by user_id from event_data for audit trail + deleted_by_user_id = event_data.get("triggered_by_user_id") if event_data else None + threads_deleted = Thread.delete_user_threads( + user_id, course_ids, deleted_by=deleted_by_user_id + ) + comments_deleted = Comment.delete_user_comments( + user_id, course_ids, deleted_by=deleted_by_user_id + ) + log.info( + f"<> Deleted {threads_deleted} posts and {comments_deleted} comments for {username} " + f"in course {course_ids}" + ) + + # Create ban record if requested + ban_id = None + ban_error = None + if ban_user: + try: + from forum import api as forum_api + + # Get user objects + target_user = User.objects.get(id=user_id) + moderator = User.objects.get(id=moderator_id) if moderator_id else None + + # Parse course key + course_key = CourseKey.from_string(course_ids[0]) if course_ids else None + + # Create ban using forum API + ban_result = forum_api.ban_user( + user=target_user, + banned_by=moderator, + course_id=course_key, + scope=ban_scope, + reason=reason or "Bulk delete and ban operation" + ) + + ban_id = ban_result.get('id') + + log.info( + f"<> Created {ban_scope}-level ban (ID: {ban_id}) " + f"for user {username} (ID: {user_id}) after deleting {threads_deleted + comments_deleted} items" + ) + + # Send escalation email (non-blocking) + try: + from lms.djangoapps.discussion.rest_api.emails import send_ban_escalation_email + + send_ban_escalation_email( + banned_user_id=user_id, + moderator_id=moderator_id, + course_id=course_ids[0] if course_ids else None, + scope=ban_scope, + reason=reason, + threads_deleted=threads_deleted, + comments_deleted=comments_deleted, + ) + except Exception as email_exc: # pylint: disable=broad-except + log.error( + "<> Failed to send ban escalation email for user %s (ID: %s): %s", + username, + user_id, + email_exc, + exc_info=True, + ) + + except Exception as e: # pylint: disable=broad-except + ban_error = str(e) + log.error( + f"<> Failed to create ban for user {username} (ID: {user_id}): {e}", + exc_info=True + ) + # Don't fail the entire task if ban creation fails + # Discussions are already deleted, so we log the error and continue + + event_data.update( + { + "number_of_posts_deleted": threads_deleted, + "number_of_comments_deleted": comments_deleted, + "ban_user": ban_user, + "ban_scope": ban_scope if ban_user else None, + "ban_id": ban_id if ban_user else None, + "ban_error": ban_error if ban_error else None, + } + ) + event_name = "edx.discussion.bulk_delete_user_posts" + tracker.emit(event_name, event_data) + segment.track("None", event_name, event_data) + + # Return task result for monitoring + return { + "threads_deleted": threads_deleted, + "comments_deleted": comments_deleted, + "ban_created": bool(ban_id), + "ban_id": ban_id, + "ban_error": ban_error, + } + + @shared_task @set_code_owner_attribute -def delete_course_post_for_user(user_id, username, course_ids, event_data=None): +def restore_course_post_for_user(user_id, username, course_ids, event_data=None): """ - Deletes all posts for user in a course. + Restores all soft-deleted posts for user in a course by setting is_deleted=False. """ event_data = event_data or {} - log.info(f"<> Deleting all posts for {username} in course {course_ids}") - threads_deleted = Thread.delete_user_threads(user_id, course_ids) - comments_deleted = Comment.delete_user_comments(user_id, course_ids) - log.info(f"<> Deleted {threads_deleted} posts and {comments_deleted} comments for {username} " - f"in course {course_ids}") - event_data.update({ - "number_of_posts_deleted": threads_deleted, - "number_of_comments_deleted": comments_deleted, - }) - event_name = 'edx.discussion.bulk_delete_user_posts' + log.info( + "<> Restoring all posts for %s in course %s", username, course_ids + ) + # Get triggered_by user_id from event_data for audit trail + restored_by_user_id = event_data.get("triggered_by_user_id") if event_data else None + threads_restored = Thread.restore_user_deleted_threads( + user_id, course_ids, restored_by=restored_by_user_id + ) + comments_restored = Comment.restore_user_deleted_comments( + user_id, course_ids, restored_by=restored_by_user_id + ) + log.info( + "<> Restored %s posts and %s comments for %s in course %s", + threads_restored, + comments_restored, + username, + course_ids, + ) + event_data.update( + { + "number_of_posts_restored": threads_restored, + "number_of_comments_restored": comments_restored, + } + ) + event_name = "edx.discussion.bulk_restore_user_posts" tracker.emit(event_name, event_data) - segment.track('None', event_name, event_data) + segment.track("None", event_name, event_data) diff --git a/lms/djangoapps/discussion/rest_api/tests/test_api.py b/lms/djangoapps/discussion/rest_api/tests/test_api.py index d23a6a06b1b5..9018ad3945d2 100644 --- a/lms/djangoapps/discussion/rest_api/tests/test_api.py +++ b/lms/djangoapps/discussion/rest_api/tests/test_api.py @@ -9,6 +9,7 @@ import ddt import httpretty import pytest +from django.core.exceptions import ValidationError from django.test import override_settings from django.contrib.auth import get_user_model from django.test.client import RequestFactory @@ -30,6 +31,8 @@ from common.djangoapps.util.testing import UrlResetMixin from lms.djangoapps.discussion.django_comment_client.tests.utils import ForumsEnableMixin from lms.djangoapps.discussion.rest_api.api import ( + create_comment, + create_thread, get_course, get_course_topics, get_user_comments, @@ -37,6 +40,7 @@ from lms.djangoapps.discussion.rest_api.exceptions import ( DiscussionDisabledError, ) +from rest_framework.exceptions import PermissionDenied from lms.djangoapps.discussion.rest_api.tests.utils import ( CommentsServiceMockMixin, make_minimal_cs_comment, @@ -50,6 +54,10 @@ FORUM_ROLE_STUDENT, Role ) +from openedx.core.djangoapps.django_comment_common.comment_client.utils import ( + CommentClient500Error, + CommentClientRequestError, +) from openedx.core.lib.exceptions import CourseNotFoundError, PageNotFoundError User = get_user_model() @@ -132,6 +140,7 @@ def test_basic(self): assert get_course(self.request, self.course.id) == { 'id': str(self.course.id), 'is_posting_enabled': True, + 'is_user_banned': False, 'blackouts': [], 'thread_list_url': 'http://testserver/api/discussion/v1/threads/?course_id=course-v1%3Ax%2By%2Bz', 'following_thread_list_url': @@ -159,7 +168,8 @@ def test_basic(self): }, "is_email_verified": True, "only_verified_users_can_post": False, - "content_creation_rate_limited": False + "content_creation_rate_limited": False, + "enable_discussion_ban": False, } @ddt.data( @@ -752,3 +762,117 @@ def test_call_with_non_existent_course(self): course_key=CourseKey.from_string("course-v1:x+y+z"), page=2, ) + + +def test_create_thread_denies_banned_user(): + request = RequestFactory().post('/dummy') + request.user = mock.Mock() + + with mock.patch( + "lms.djangoapps.discussion.rest_api.api._get_course", + return_value=mock.Mock(), + ), mock.patch( + "lms.djangoapps.discussion.rest_api.api.get_context", + return_value={}, + ), mock.patch( + "lms.djangoapps.discussion.rest_api.api.discussion_open_for_user", + return_value=True, + ), mock.patch( + "lms.djangoapps.discussion.rest_api.api._check_initializable_thread_fields", + side_effect=ValidationError("downstream validation"), + ), mock.patch( + "lms.djangoapps.discussion.rest_api.api.ENABLE_DISCUSSION_BAN.is_enabled", + return_value=True, + ), mock.patch( + "lms.djangoapps.discussion.rest_api.api.forum_api.is_user_banned", + return_value=True, + create=True, + ): + with pytest.raises(PermissionDenied, match="You are banned from posting"): + create_thread(request, {"course_id": "course-v1:x+y+z"}) + + +def test_create_comment_denies_banned_user(): + request = RequestFactory().post('/dummy') + request.user = mock.Mock() + course = mock.Mock() + course.id = CourseKey.from_string("course-v1:x+y+z") + + with mock.patch( + "lms.djangoapps.discussion.rest_api.api._get_thread_and_context", + return_value=({"closed": False}, {"course": course}), + ), mock.patch( + "lms.djangoapps.discussion.rest_api.api.discussion_open_for_user", + return_value=True, + ), mock.patch( + "lms.djangoapps.discussion.rest_api.api._check_initializable_comment_fields", + side_effect=ValidationError("downstream validation"), + ), mock.patch( + "lms.djangoapps.discussion.rest_api.api.ENABLE_DISCUSSION_BAN.is_enabled", + return_value=True, + ), mock.patch( + "lms.djangoapps.discussion.rest_api.api.forum_api.is_user_banned", + return_value=True, + create=True, + ): + with pytest.raises(PermissionDenied, match="You are banned from posting"): + create_comment(request, {"thread_id": "test_thread"}) + + +def test_create_thread_ban_check_backend_error_fails_open(): + request = RequestFactory().post('/dummy') + request.user = mock.Mock(id=123) + + with mock.patch( + "lms.djangoapps.discussion.rest_api.api._get_course", + return_value=mock.Mock(), + ), mock.patch( + "lms.djangoapps.discussion.rest_api.api.get_context", + return_value={}, + ), mock.patch( + "lms.djangoapps.discussion.rest_api.api.discussion_open_for_user", + return_value=True, + ), mock.patch( + "lms.djangoapps.discussion.rest_api.api._check_initializable_thread_fields", + side_effect=ValidationError("downstream validation"), + ), mock.patch( + "lms.djangoapps.discussion.rest_api.api.ENABLE_DISCUSSION_BAN.is_enabled", + return_value=True, + ), mock.patch( + "lms.djangoapps.discussion.rest_api.api.forum_api.is_user_banned", + side_effect=CommentClientRequestError("temporary backend failure"), + create=True, + ), mock.patch("lms.djangoapps.discussion.rest_api.api.log.warning") as warning_log: + with pytest.raises(ValidationError): + create_thread(request, {"course_id": "course-v1:x+y+z"}) + + warning_log.assert_called_once() + + +def test_create_comment_ban_check_backend_error_fails_open(): + request = RequestFactory().post('/dummy') + request.user = mock.Mock(id=123) + course = mock.Mock() + course.id = CourseKey.from_string("course-v1:x+y+z") + + with mock.patch( + "lms.djangoapps.discussion.rest_api.api._get_thread_and_context", + return_value=({"closed": False}, {"course": course}), + ), mock.patch( + "lms.djangoapps.discussion.rest_api.api.discussion_open_for_user", + return_value=True, + ), mock.patch( + "lms.djangoapps.discussion.rest_api.api._check_initializable_comment_fields", + side_effect=ValidationError("downstream validation"), + ), mock.patch( + "lms.djangoapps.discussion.rest_api.api.ENABLE_DISCUSSION_BAN.is_enabled", + return_value=True, + ), mock.patch( + "lms.djangoapps.discussion.rest_api.api.forum_api.is_user_banned", + side_effect=CommentClient500Error("temporary backend failure"), + create=True, + ), mock.patch("lms.djangoapps.discussion.rest_api.api.log.warning") as warning_log: + with pytest.raises(ValidationError): + create_comment(request, {"thread_id": "test_thread"}) + + warning_log.assert_called_once() diff --git a/lms/djangoapps/discussion/rest_api/tests/test_api_v2.py b/lms/djangoapps/discussion/rest_api/tests/test_api_v2.py index f5b15b905639..7f85c2c8c210 100644 --- a/lms/djangoapps/discussion/rest_api/tests/test_api_v2.py +++ b/lms/djangoapps/discussion/rest_api/tests/test_api_v2.py @@ -10,34 +10,20 @@ import random from datetime import datetime, timedelta from unittest import mock -from urllib.parse import parse_qs, urlencode, urlparse, urlunparse import ddt import httpretty import pytest -from django.test import override_settings from django.contrib.auth import get_user_model from django.core.exceptions import ValidationError from django.test.client import RequestFactory -from opaque_keys.edx.keys import CourseKey from opaque_keys.edx.locator import CourseLocator from pytz import UTC from rest_framework.exceptions import PermissionDenied -from xmodule.modulestore import ModuleStoreEnum -from xmodule.modulestore.django import modulestore -from xmodule.modulestore.tests.django_utils import ( - ModuleStoreTestCase, - SharedModuleStoreTestCase, -) -from xmodule.modulestore.tests.factories import CourseFactory, BlockFactory -from xmodule.partitions.partitions import Group, UserPartition - from common.djangoapps.student.tests.factories import ( AdminFactory, - BetaTesterFactory, CourseEnrollmentFactory, - StaffFactory, UserFactory, ) from common.djangoapps.util.testing import UrlResetMixin @@ -45,10 +31,6 @@ from lms.djangoapps.discussion.django_comment_client.tests.utils import ( ForumsEnableMixin, ) -from lms.djangoapps.discussion.tests.utils import ( - make_minimal_cs_comment, - make_minimal_cs_thread, -) from lms.djangoapps.discussion.rest_api import api from lms.djangoapps.discussion.rest_api.api import ( create_comment, @@ -56,12 +38,9 @@ delete_comment, delete_thread, get_comment_list, - get_course, - get_course_topics, get_course_topics_v2, get_thread, get_thread_list, - get_user_comments, update_comment, update_thread, ) @@ -73,18 +52,19 @@ ) from lms.djangoapps.discussion.rest_api.serializers import TopicOrdering from lms.djangoapps.discussion.rest_api.tests.utils import ( - CommentsServiceMockMixin, ForumMockUtilsMixin, make_paginated_api_response, - parsed_body, ) -from openedx.core.djangoapps.course_groups.models import CourseUserGroupPartitionGroup +from lms.djangoapps.discussion.tests.utils import ( + make_minimal_cs_comment, + make_minimal_cs_thread, +) from openedx.core.djangoapps.course_groups.tests.helpers import CohortFactory from openedx.core.djangoapps.discussions.models import ( DiscussionsConfiguration, DiscussionTopicLink, - Provider, PostingRestriction, + Provider, ) from openedx.core.djangoapps.discussions.tasks import ( update_discussions_settings_from_course_task, @@ -98,6 +78,13 @@ Role, ) from openedx.core.lib.exceptions import CourseNotFoundError, PageNotFoundError +from xmodule.modulestore import ModuleStoreEnum +from xmodule.modulestore.django import modulestore +from xmodule.modulestore.tests.django_utils import ( + ModuleStoreTestCase, + SharedModuleStoreTestCase, +) +from xmodule.modulestore.tests.factories import BlockFactory, CourseFactory User = get_user_model() @@ -274,7 +261,11 @@ def test_basic(self, mock_emit): ) self.register_post_thread_response(cs_thread) with self.assert_signal_sent( - api, "thread_created", sender=None, user=self.user, exclude_args=("post", "notify_all_learners") + api, + "thread_created", + sender=None, + user=self.user, + exclude_args=("post", "notify_all_learners"), ): actual = create_thread(self.request, self.minimal_data) expected = self.expected_thread_data( @@ -353,7 +344,11 @@ def test_basic_in_blackout_period_with_user_access(self, mock_emit): ) with self.assert_signal_sent( - api, "thread_created", sender=None, user=self.user, exclude_args=("post", "notify_all_learners") + api, + "thread_created", + sender=None, + user=self.user, + exclude_args=("post", "notify_all_learners"), ): actual = create_thread(self.request, self.minimal_data) expected = self.expected_thread_data( @@ -363,6 +358,7 @@ def test_basic_in_blackout_period_with_user_access(self, mock_emit): "course_id": str(self.course.id), "comment_list_url": "http://testserver/api/discussion/v1/comments/?thread_id=test_id", "read": True, + "learner_status": "staff", "editable_fields": [ "abuse_flagged", "anonymous", @@ -378,6 +374,7 @@ def test_basic_in_blackout_period_with_user_access(self, mock_emit): "type", "voted", ], + "is_deleted": False, } ) assert actual == expected @@ -429,7 +426,11 @@ def test_title_truncation(self, mock_emit): ) self.register_post_thread_response(cs_thread) with self.assert_signal_sent( - api, "thread_created", sender=None, user=self.user, exclude_args=("post", "notify_all_learners") + api, + "thread_created", + sender=None, + user=self.user, + exclude_args=("post", "notify_all_learners"), ): create_thread(self.request, data) event_name, event_data = mock_emit.call_args[0] @@ -689,6 +690,9 @@ def test_success(self, parent_id, mock_emit): "parent_id": parent_id, "author": self.user.username, "author_label": None, + "is_author_banned": False, + "author_ban_scope": None, + "learner_status": "new", "created_at": "2015-05-27T00:00:00Z", "updated_at": "2015-05-27T00:00:00Z", "raw_body": "Test body", @@ -716,6 +720,10 @@ def test_success(self, parent_id, mock_emit): "image_url_medium": "http://testserver/static/default_50.png", "image_url_small": "http://testserver/static/default_30.png", }, + "is_deleted": False, + "deleted_at": None, + "deleted_by": None, + "deleted_by_label": None, } assert actual == expected @@ -796,6 +804,9 @@ def test_success_in_black_out_with_user_access(self, parent_id, mock_emit): "parent_id": parent_id, "author": self.user.username, "author_label": "Moderator", + "is_author_banned": False, + "author_ban_scope": None, + "learner_status": "staff", "created_at": "2015-05-27T00:00:00Z", "updated_at": "2015-05-27T00:00:00Z", "raw_body": "Test body", @@ -823,6 +834,10 @@ def test_success_in_black_out_with_user_access(self, parent_id, mock_emit): "image_url_medium": "http://testserver/static/default_50.png", "image_url_small": "http://testserver/static/default_30.png", }, + "is_deleted": False, + "deleted_at": None, + "deleted_by": None, + "deleted_by_label": None, } assert actual == expected @@ -911,7 +926,9 @@ def test_endorsed(self, role_name, is_thread_author, thread_type): ) try: create_comment(self.request, data) - last_commemt_params = self.get_mock_func_calls("create_parent_comment")[-1][1] + last_commemt_params = self.get_mock_func_calls("create_parent_comment")[-1][ + 1 + ] assert last_commemt_params["endorsed"] assert not expected_error except ValidationError: @@ -1799,6 +1816,9 @@ def test_basic(self, parent_id): "parent_id": parent_id, "author": self.user.username, "author_label": None, + "is_author_banned": False, + "author_ban_scope": None, + "learner_status": "new", "created_at": "2015-06-03T00:00:00Z", "updated_at": "2015-06-03T00:00:00Z", "raw_body": "Edited body", @@ -1824,6 +1844,10 @@ def test_basic(self, parent_id): "image_url_medium": "http://testserver/static/default_50.png", "image_url_small": "http://testserver/static/default_30.png", }, + "is_deleted": False, + "deleted_at": None, + "deleted_by": None, + "deleted_by_label": None, } assert actual == expected params = { @@ -1884,7 +1908,7 @@ def test_abuse_flagged(self, old_flagged, new_flagged, mock_emit): else "edx.forum.response.unreported" ) expected_event_data = { - "discussion": {'id': 'test_thread'}, + "discussion": {"id": "test_thread"}, "body": "Original body", "id": "test_comment", "content_type": "Response", @@ -1947,7 +1971,7 @@ def test_comment_un_abuse_flag_for_moderator_role( "body": "Original body", "id": "test_comment", "content_type": "Response", - "discussion": {'id': 'test_thread'}, + "discussion": {"id": "test_thread"}, "commentable_id": "dummy", "truncated": False, "url": "", @@ -2366,6 +2390,7 @@ def test_basic(self, mock_emit): params = { "thread_id": self.thread_id, "course_id": str(self.course.id), + "deleted_by": str(self.user.id), } self.check_mock_called_with("delete_thread", -1, **params) @@ -2553,6 +2578,7 @@ def test_basic(self, mock_emit): params = { "comment_id": self.comment_id, "course_id": str(self.course.id), + "deleted_by": str(self.user.id), } self.check_mock_called_with("delete_comment", -1, **params) @@ -2722,6 +2748,7 @@ def register_thread(self, overrides=None): "title": "Test Title", "body": "Test body", "resp_total": 0, + "is_deleted": False, } ) cs_data.update(overrides or {}) @@ -2756,6 +2783,7 @@ def test_nonauthor_enrolled_in_course(self): "voted", ], "unread_comment_count": 1, + "is_deleted": None, } ) self.check_mock_called("get_thread") @@ -2917,6 +2945,7 @@ def test_get_threads_by_topic_id(self): "page": 1, "per_page": 1, "commentable_ids": ["topic_x", "topic_meow"], + "show_deleted": False, } self.check_mock_called_with( "get_user_threads", @@ -2932,6 +2961,7 @@ def test_basic_query_params(self): "sort_key": "activity", "page": 6, "per_page": 14, + "show_deleted": False, } self.check_mock_called_with( "get_user_threads", @@ -2959,6 +2989,7 @@ def test_thread_content(self): "read": True, "created_at": "2015-04-28T00:00:00Z", "updated_at": "2015-04-28T11:11:11Z", + "is_deleted": False, } ), make_minimal_cs_thread( @@ -2976,6 +3007,7 @@ def test_thread_content(self): "comments_count": 18, "created_at": "2015-04-28T22:22:22Z", "updated_at": "2015-04-28T00:33:33Z", + "is_deleted": False, } ), ] @@ -3002,6 +3034,7 @@ def test_thread_content(self): "updated_at": "2015-04-28T11:11:11Z", "abuse_flagged_count": None, "can_delete": False, + "is_deleted": None, } ), self.expected_thread_data( @@ -3036,6 +3069,7 @@ def test_thread_content(self): ], "abuse_flagged_count": None, "can_delete": False, + "is_deleted": None, } ), ] @@ -3072,10 +3106,10 @@ def test_request_group(self, role_name, course_is_cohorted): self.get_thread_list([], course=cohort_course) thread_func_params = self.get_mock_func_calls("get_user_threads")[-1][1] actual_has_group = "group_id" in thread_func_params - expected_has_group = ( - course_is_cohorted and role_name in ( - FORUM_ROLE_STUDENT, FORUM_ROLE_COMMUNITY_TA, FORUM_ROLE_GROUP_MODERATOR - ) + expected_has_group = course_is_cohorted and role_name in ( + FORUM_ROLE_STUDENT, + FORUM_ROLE_COMMUNITY_TA, + FORUM_ROLE_GROUP_MODERATOR, ) assert actual_has_group == expected_has_group @@ -3140,6 +3174,7 @@ def test_text_search(self, text_search_rewrite): "page": 1, "per_page": 10, "text": "test search string", + "show_deleted": False, } self.check_mock_called_with( "search_threads", @@ -3166,6 +3201,7 @@ def test_filter_threads_by_author(self): "page": 1, "per_page": 10, "author_id": str(self.user.id), + "show_deleted": False, } self.check_mock_called_with( "get_user_threads", @@ -3212,6 +3248,7 @@ def test_thread_type(self, thread_type): "page": 1, "per_page": 10, "thread_type": thread_type, + "show_deleted": False, } if thread_type is None: @@ -3249,6 +3286,7 @@ def test_flagged(self, flagged_boolean): "page": 1, "per_page": 10, "flagged": flagged_boolean, + "show_deleted": False, } if flagged_boolean is None: @@ -3289,6 +3327,7 @@ def test_flagged_count(self, role): "count_flagged": True, "page": 1, "per_page": 10, + "show_deleted": False, } self.check_mock_called_with( @@ -3337,6 +3376,7 @@ def test_following(self): "sort_key": "activity", "page": 1, "per_page": 11, + "show_deleted": False, } self.check_mock_called_with("get_user_subscriptions", -1, **params) @@ -3364,6 +3404,7 @@ def test_view_query(self, query): "page": 1, "per_page": 11, query: True, + "show_deleted": False, } self.check_mock_called_with( "get_user_threads", @@ -3405,6 +3446,7 @@ def test_order_by_query(self, http_query, cc_query): "sort_key": cc_query, "page": 1, "per_page": 11, + "show_deleted": False, } self.check_mock_called_with( "get_user_threads", @@ -3437,6 +3479,7 @@ def test_order_direction(self): "sort_key": "activity", "page": 1, "per_page": 11, + "show_deleted": False, } self.check_mock_called_with( "get_user_threads", @@ -3711,6 +3754,7 @@ def get_source_and_expected_comments(self): "votes": {"up_count": 4}, "child_count": 0, "children": [], + "is_deleted": False, }, { "type": "comment", @@ -3728,6 +3772,7 @@ def get_source_and_expected_comments(self): "votes": {"up_count": 7}, "child_count": 0, "children": [], + "is_deleted": False, }, ] expected_comments = [ @@ -3737,6 +3782,9 @@ def get_source_and_expected_comments(self): "parent_id": None, "author": self.author.username, "author_label": None, + "is_author_banned": False, + "author_ban_scope": None, + "learner_status": "new", "created_at": "2015-05-11T00:00:00Z", "updated_at": "2015-05-11T11:11:11Z", "raw_body": "Test body", @@ -3764,6 +3812,10 @@ def get_source_and_expected_comments(self): "image_url_medium": "http://testserver/static/default_50.png", "image_url_small": "http://testserver/static/default_30.png", }, + "is_deleted": None, + "deleted_at": None, + "deleted_by": None, + "deleted_by_label": None, }, { "id": "test_comment_2", @@ -3771,6 +3823,9 @@ def get_source_and_expected_comments(self): "parent_id": None, "author": None, "author_label": None, + "is_author_banned": False, + "author_ban_scope": None, + "learner_status": "anonymous", "created_at": "2015-05-11T22:22:22Z", "updated_at": "2015-05-11T33:33:33Z", "raw_body": "More content", @@ -3798,6 +3853,10 @@ def get_source_and_expected_comments(self): "image_url_medium": "http://testserver/static/default_50.png", "image_url_small": "http://testserver/static/default_30.png", }, + "is_deleted": None, + "deleted_at": None, + "deleted_by": None, + "deleted_by_label": None, }, ] return source_comments, expected_comments diff --git a/lms/djangoapps/discussion/rest_api/tests/test_forms.py b/lms/djangoapps/discussion/rest_api/tests/test_forms.py index 3be65964b6b9..33359337933b 100644 --- a/lms/djangoapps/discussion/rest_api/tests/test_forms.py +++ b/lms/djangoapps/discussion/rest_api/tests/test_forms.py @@ -2,7 +2,6 @@ Tests for Discussion API forms """ - import itertools from unittest import TestCase from urllib.parse import urlencode @@ -12,9 +11,9 @@ from opaque_keys.edx.locator import CourseLocator from lms.djangoapps.discussion.rest_api.forms import ( - UserCommentListGetForm, CommentListGetForm, ThreadListGetForm, + UserCommentListGetForm, ) from openedx.core.djangoapps.util.test_forms import FormTestMixin @@ -36,7 +35,9 @@ def test_missing_page_size(self): def test_zero_page_size(self): self.form_data["page_size"] = "0" - self.assert_error("page_size", "Ensure this value is greater than or equal to 1.") + self.assert_error( + "page_size", "Ensure this value is greater than or equal to 1." + ) def test_excessive_page_size(self): self.form_data["page_size"] = "101" @@ -46,6 +47,7 @@ def test_excessive_page_size(self): @ddt.ddt class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase): """Tests for ThreadListGetForm""" + FORM_CLASS = ThreadListGetForm def setUp(self): @@ -58,37 +60,41 @@ def setUp(self): "page_size": "13", } ), - mutable=True + mutable=True, ) def test_basic(self): form = self.get_form(expected_valid=True) assert form.cleaned_data == { - 'course_id': CourseLocator.from_string('Foo/Bar/Baz'), - 'page': 2, - 'page_size': 13, - 'count_flagged': None, - 'topic_id': set(), - 'text_search': '', - 'following': None, - 'author': '', - 'thread_type': '', - 'flagged': None, - 'view': '', - 'order_by': 'last_activity_at', - 'order_direction': 'desc', - 'requested_fields': set() + "course_id": CourseLocator.from_string("Foo/Bar/Baz"), + "page": 2, + "page_size": 13, + "count_flagged": None, + "topic_id": set(), + "text_search": "", + "following": None, + "author": "", + "thread_type": "", + "flagged": None, + "show_deleted": None, + "view": "", + "order_by": "last_activity_at", + "order_direction": "desc", + "requested_fields": set(), } def test_topic_id(self): self.form_data.setlist("topic_id", ["example topic_id", "example 2nd topic_id"]) form = self.get_form(expected_valid=True) - assert form.cleaned_data['topic_id'] == {'example topic_id', 'example 2nd topic_id'} + assert form.cleaned_data["topic_id"] == { + "example topic_id", + "example 2nd topic_id", + } def test_text_search(self): self.form_data["text_search"] = "test search string" form = self.get_form(expected_valid=True) - assert form.cleaned_data['text_search'] == 'test search string' + assert form.cleaned_data["text_search"] == "test search string" def test_missing_course_id(self): self.form_data.pop("course_id") @@ -109,7 +115,10 @@ def test_thread_type(self, value): def test_thread_type_invalid(self): self.form_data["thread_type"] = "invalid-option" - self.assert_error("thread_type", "Select a valid choice. invalid-option is not one of the available choices.") + self.assert_error( + "thread_type", + "Select a valid choice. invalid-option is not one of the available choices.", + ) @ddt.data("True", "true", 1, True) def test_flagged_true(self, value): @@ -133,7 +142,9 @@ def test_following_true(self, value): @ddt.data("False", "false", 0, False) def test_following_false(self, value): self.form_data["following"] = value - self.assert_error("following", "The value of the 'following' parameter must be true.") + self.assert_error( + "following", "The value of the 'following' parameter must be true." + ) def test_invalid_following(self): self.form_data["following"] = "invalid-boolean" @@ -144,25 +155,28 @@ def test_mutually_exclusive(self, params): self.form_data.update({param: "True" for param in params}) self.assert_error( "__all__", - "The following query parameters are mutually exclusive: topic_id, text_search, following" + "The following query parameters are mutually exclusive: topic_id, text_search, following", ) def test_invalid_view_choice(self): self.form_data["view"] = "not_a_valid_choice" - self.assert_error("view", "Select a valid choice. not_a_valid_choice is not one of the available choices.") + self.assert_error( + "view", + "Select a valid choice. not_a_valid_choice is not one of the available choices.", + ) def test_invalid_sort_by_choice(self): self.form_data["order_by"] = "not_a_valid_choice" self.assert_error( "order_by", - "Select a valid choice. not_a_valid_choice is not one of the available choices." + "Select a valid choice. not_a_valid_choice is not one of the available choices.", ) def test_invalid_sort_direction_choice(self): self.form_data["order_direction"] = "not_a_valid_choice" self.assert_error( "order_direction", - "Select a valid choice. not_a_valid_choice is not one of the available choices." + "Select a valid choice. not_a_valid_choice is not one of the available choices.", ) @ddt.data( @@ -181,12 +195,13 @@ def test_valid_choice_fields(self, field, value): def test_requested_fields(self): self.form_data["requested_fields"] = "profile_image" form = self.get_form(expected_valid=True) - assert form.cleaned_data['requested_fields'] == {'profile_image'} + assert form.cleaned_data["requested_fields"] == {"profile_image"} @ddt.ddt class CommentListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase): """Tests for CommentListGetForm""" + FORM_CLASS = CommentListGetForm def setUp(self): @@ -202,13 +217,14 @@ def setUp(self): def test_basic(self): form = self.get_form(expected_valid=True) assert form.cleaned_data == { - 'thread_id': 'deadbeef', - 'endorsed': False, - 'page': 2, - 'page_size': 13, - 'flagged': False, - 'requested_fields': set(), - 'merge_question_type_responses': False + "thread_id": "deadbeef", + "endorsed": False, + "page": 2, + "page_size": 13, + "flagged": False, + "requested_fields": set(), + "merge_question_type_responses": False, + "show_deleted": None, } def test_missing_thread_id(self): @@ -236,12 +252,13 @@ def test_invalid_endorsed(self): def test_requested_fields(self): self.form_data["requested_fields"] = {"profile_image"} form = self.get_form(expected_valid=True) - assert form.cleaned_data['requested_fields'] == {'profile_image'} + assert form.cleaned_data["requested_fields"] == {"profile_image"} @ddt.ddt class UserCommentListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase): """Tests for UserCommentListGetForm""" + FORM_CLASS = UserCommentListGetForm def setUp(self): @@ -256,11 +273,11 @@ def setUp(self): def test_basic(self): form = self.get_form(expected_valid=True) assert form.cleaned_data == { - 'course_id': CourseLocator.from_string('a/b/c'), - 'flagged': False, - 'page': 2, - 'page_size': 13, - 'requested_fields': set() + "course_id": CourseLocator.from_string("a/b/c"), + "flagged": False, + "page": 2, + "page_size": 13, + "requested_fields": set(), } def test_missing_flagged(self): @@ -280,7 +297,7 @@ def test_flagged_true(self, value): def test_requested_fields(self): self.form_data["requested_fields"] = {"profile_image"} form = self.get_form(expected_valid=True) - assert form.cleaned_data['requested_fields'] == {'profile_image'} + assert form.cleaned_data["requested_fields"] == {"profile_image"} def test_missing_course_id(self): self.form_data.pop("course_id") diff --git a/lms/djangoapps/discussion/rest_api/tests/test_moderation_emails.py b/lms/djangoapps/discussion/rest_api/tests/test_moderation_emails.py new file mode 100644 index 000000000000..9256596945e2 --- /dev/null +++ b/lms/djangoapps/discussion/rest_api/tests/test_moderation_emails.py @@ -0,0 +1,238 @@ +""" +Tests for discussion moderation email notifications. +""" +from unittest import mock +from django.test import override_settings +from django.core import mail + +from lms.djangoapps.discussion.rest_api.emails import send_ban_escalation_email +from common.djangoapps.student.tests.factories import UserFactory +from xmodule.modulestore.tests.factories import CourseFactory +from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase + + +class BanEscalationEmailTest(ModuleStoreTestCase): + """Tests for send_ban_escalation_email function.""" + + def setUp(self): + super().setUp() + self.course = CourseFactory.create(org='TestX', number='CS101', run='2024') + self.course_key = str(self.course.id) + self.banned_user = UserFactory.create(username='spammer', email='spammer@example.com') + self.moderator = UserFactory.create(username='moderator', email='mod@example.com') + + @override_settings(DISCUSSION_MODERATION_BAN_EMAIL_ENABLED=False) + def test_email_disabled_by_setting(self): + """Test that email is not sent when DISCUSSION_MODERATION_BAN_EMAIL_ENABLED is False.""" + # Clear outbox + mail.outbox = [] + + # Try to send email + send_ban_escalation_email( + banned_user_id=self.banned_user.id, + moderator_id=self.moderator.id, + course_id=self.course_key, + scope='course', + reason='Spam', + threads_deleted=5, + comments_deleted=10 + ) + + # No email should be sent + self.assertEqual(len(mail.outbox), 0) + + @override_settings( + DISCUSSION_MODERATION_BAN_EMAIL_ENABLED=True, + DISCUSSION_MODERATION_ESCALATION_EMAIL='partner-support@edx.org' + ) + @mock.patch('lms.djangoapps.discussion.rest_api.emails.ace') + def test_email_sent_via_ace(self, mock_ace_module): + """Test that email is sent via ACE when available.""" + # Create mock ACE send function + mock_send = mock.MagicMock() + mock_ace_module.send = mock_send + + send_ban_escalation_email( + banned_user_id=self.banned_user.id, + moderator_id=self.moderator.id, + course_id=self.course_key, + scope='course', + reason='Posting scam links', + threads_deleted=3, + comments_deleted=7 + ) + + # ACE send should be called + mock_send.assert_called_once() + + # Get the message argument + call_args = mock_send.call_args + message = call_args[0][0] + + # Verify message properties + self.assertEqual(message.recipient.email_address, 'partner-support@edx.org') + self.assertEqual(message.context['banned_username'], 'spammer') + self.assertEqual(message.context['moderator_username'], 'moderator') + self.assertEqual(message.context['scope'], 'course') + self.assertEqual(message.context['reason'], 'Posting scam links') + self.assertEqual(message.context['threads_deleted'], 3) + self.assertEqual(message.context['comments_deleted'], 7) + self.assertEqual(message.context['total_deleted'], 10) + + @override_settings( + DISCUSSION_MODERATION_BAN_EMAIL_ENABLED=True, + DISCUSSION_MODERATION_ESCALATION_EMAIL='custom-support@example.com', + DEFAULT_FROM_EMAIL='noreply@edx.org' + ) + @mock.patch('lms.djangoapps.discussion.rest_api.emails.ace', None) + def test_email_fallback_to_django_mail(self): + """Test that email falls back to Django mail when ACE is not available.""" + # Clear outbox + mail.outbox = [] + + # Simulate ACE not being importable by making the import fail + import sys + original_modules = sys.modules.copy() + + # Remove ace modules if present + ace_modules = [key for key in sys.modules if key.startswith('edx_ace')] + for mod in ace_modules: + sys.modules.pop(mod, None) + + try: + send_ban_escalation_email( + banned_user_id=self.banned_user.id, + moderator_id=self.moderator.id, + course_id=self.course_key, + scope='organization', + reason='Multiple violations', + threads_deleted=15, + comments_deleted=25 + ) + finally: + # Restore modules + sys.modules.update(original_modules) + + # Email should be sent via Django + self.assertEqual(len(mail.outbox), 1) + + email = mail.outbox[0] + self.assertIn('custom-support@example.com', email.to) + self.assertEqual(email.from_email, 'noreply@edx.org') + self.assertIn('spammer', email.body) + self.assertIn('moderator', email.body) + self.assertIn('Multiple violations', email.body) + self.assertIn('ORGANIZATION', email.body) + self.assertIn('15', email.body) # threads_deleted + self.assertIn('25', email.body) # comments_deleted + + @override_settings( + DISCUSSION_MODERATION_BAN_EMAIL_ENABLED=True, + DISCUSSION_MODERATION_ESCALATION_EMAIL='support@example.com' + ) + @mock.patch('lms.djangoapps.discussion.rest_api.emails.ace', None) + def test_email_handles_missing_reason(self): + """Test that email handles empty/None reason gracefully.""" + mail.outbox = [] + + # Send with empty reason (will use Django mail since ace is None) + send_ban_escalation_email( + banned_user_id=self.banned_user.id, + moderator_id=self.moderator.id, + course_id=self.course_key, + scope='course', + reason='', + threads_deleted=1, + comments_deleted=0 + ) + + self.assertEqual(len(mail.outbox), 1) + email = mail.outbox[0] + # Should use default text + self.assertIn('No reason provided', email.body) + + @override_settings( + DISCUSSION_MODERATION_BAN_EMAIL_ENABLED=True, + DISCUSSION_MODERATION_ESCALATION_EMAIL='support@example.com' + ) + @mock.patch('lms.djangoapps.discussion.rest_api.emails.ace', None) + def test_email_with_org_level_ban(self): + """Test email for organization-level ban.""" + mail.outbox = [] + + send_ban_escalation_email( + banned_user_id=self.banned_user.id, + moderator_id=self.moderator.id, + course_id=self.course_key, + scope='organization', + reason='Org-wide spam campaign', + threads_deleted=50, + comments_deleted=100 + ) + + self.assertEqual(len(mail.outbox), 1) + email = mail.outbox[0] + self.assertIn('ORGANIZATION', email.body) + self.assertIn('Org-wide spam campaign', email.body) + + @override_settings( + DISCUSSION_MODERATION_BAN_EMAIL_ENABLED=True, + DISCUSSION_MODERATION_ESCALATION_EMAIL='support@example.com' + ) + @mock.patch('lms.djangoapps.discussion.rest_api.emails.ace', None) + def test_email_failure_logged(self): + """Test that email failures are properly logged.""" + with mock.patch('django.core.mail.send_mail', side_effect=Exception("SMTP error")): + with self.assertLogs('lms.djangoapps.discussion.rest_api.emails', level='ERROR') as logs: + with self.assertRaises(Exception): + send_ban_escalation_email( + banned_user_id=self.banned_user.id, + moderator_id=self.moderator.id, + course_id=self.course_key, + scope='course', + reason='Test', + threads_deleted=1, + comments_deleted=1 + ) + + # Verify error was logged + self.assertTrue(any('Failed to send ban escalation email' in log for log in logs.output)) + + @override_settings(DISCUSSION_MODERATION_BAN_EMAIL_ENABLED=True) + def test_email_with_invalid_user_id(self): + """Test that email handles invalid user IDs gracefully.""" + with self.assertRaises(Exception): + send_ban_escalation_email( + banned_user_id=99999, # Non-existent user + moderator_id=self.moderator.id, + course_id=self.course_key, + scope='course', + reason='Test', + threads_deleted=0, + comments_deleted=0 + ) + + @override_settings( + DISCUSSION_MODERATION_BAN_EMAIL_ENABLED=True, + DISCUSSION_MODERATION_ESCALATION_EMAIL='test@example.com' + ) + @mock.patch('lms.djangoapps.discussion.rest_api.emails.ace', None) + def test_email_subject_format(self): + """Test that email subject is properly formatted.""" + mail.outbox = [] + + send_ban_escalation_email( + banned_user_id=self.banned_user.id, + moderator_id=self.moderator.id, + course_id=self.course_key, + scope='course', + reason='Test ban', + threads_deleted=1, + comments_deleted=1 + ) + + self.assertEqual(len(mail.outbox), 1) + email = mail.outbox[0] + # Subject should contain username and course + self.assertIn('spammer', email.subject) + self.assertIn(self.course_key, email.subject) diff --git a/lms/djangoapps/discussion/rest_api/tests/test_moderation_permissions.py b/lms/djangoapps/discussion/rest_api/tests/test_moderation_permissions.py new file mode 100644 index 000000000000..73981fd5b2cb --- /dev/null +++ b/lms/djangoapps/discussion/rest_api/tests/test_moderation_permissions.py @@ -0,0 +1,203 @@ +""" +Tests for discussion moderation permissions. +""" +from unittest.mock import Mock + +from rest_framework.test import APIRequestFactory + +from common.djangoapps.student.roles import CourseStaffRole, CourseInstructorRole, GlobalStaff +from common.djangoapps.student.tests.factories import UserFactory +from lms.djangoapps.discussion.rest_api.permissions import ( + IsAllowedToBulkDelete, + can_take_action_on_spam, +) +from openedx.core.djangoapps.django_comment_common.models import Role +from xmodule.modulestore.tests.factories import CourseFactory +from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase + + +class CanTakeActionOnSpamTest(ModuleStoreTestCase): + """Tests for can_take_action_on_spam permission helper function.""" + + def setUp(self): + super().setUp() + self.course = CourseFactory.create(org='TestX', number='CS101', run='2024') + self.course_key = self.course.id + + def test_global_staff_has_permission(self): + """Global staff should have permission.""" + user = UserFactory.create(is_staff=True) + self.assertTrue(can_take_action_on_spam(user, self.course_key)) + + def test_global_staff_role_has_permission(self): + """Users with GlobalStaff role should have permission.""" + user = UserFactory.create() + GlobalStaff().add_users(user) + self.assertTrue(can_take_action_on_spam(user, self.course_key)) + + def test_course_staff_has_permission(self): + """Course staff should have permission for their course.""" + user = UserFactory.create() + CourseStaffRole(self.course_key).add_users(user) + self.assertTrue(can_take_action_on_spam(user, self.course_key)) + + def test_course_instructor_has_permission(self): + """Course instructors should have permission for their course.""" + user = UserFactory.create() + CourseInstructorRole(self.course_key).add_users(user) + self.assertTrue(can_take_action_on_spam(user, self.course_key)) + + def test_forum_moderator_has_permission(self): + """Forum moderators should have permission for their course.""" + user = UserFactory.create() + role = Role.objects.create(name='Moderator', course_id=self.course_key) + role.users.add(user) + self.assertTrue(can_take_action_on_spam(user, self.course_key)) + + def test_forum_administrator_has_permission(self): + """Forum administrators should have permission for their course.""" + user = UserFactory.create() + role = Role.objects.create(name='Administrator', course_id=self.course_key) + role.users.add(user) + self.assertTrue(can_take_action_on_spam(user, self.course_key)) + + def test_regular_student_no_permission(self): + """Regular students should not have permission.""" + user = UserFactory.create() + self.assertFalse(can_take_action_on_spam(user, self.course_key)) + + def test_community_ta_no_permission(self): + """Community TAs should not have bulk delete permission.""" + user = UserFactory.create() + role = Role.objects.create(name='Community TA', course_id=self.course_key) + role.users.add(user) + self.assertFalse(can_take_action_on_spam(user, self.course_key)) + + def test_staff_different_course_no_permission(self): + """Staff from a different course should not have permission.""" + other_course = CourseFactory.create(org='OtherX', number='CS201', run='2024') + user = UserFactory.create() + CourseStaffRole(other_course.id).add_users(user) + self.assertFalse(can_take_action_on_spam(user, self.course_key)) + + def test_accepts_string_course_id(self): + """Function should accept string course_id and convert it.""" + user = UserFactory.create() + CourseStaffRole(self.course_key).add_users(user) + self.assertTrue(can_take_action_on_spam(user, str(self.course_key))) + + +class IsAllowedToBulkDeleteTest(ModuleStoreTestCase): + """Tests for IsAllowedToBulkDelete permission class.""" + + def setUp(self): + super().setUp() + self.course = CourseFactory.create(org='TestX', number='CS101', run='2024') + self.course_key = str(self.course.id) + self.factory = APIRequestFactory() + self.permission = IsAllowedToBulkDelete() + + def _create_view_with_kwargs(self, course_id=None): + """Helper to create a mock view with kwargs.""" + view = Mock() + view.kwargs = {'course_id': course_id} if course_id else {} + return view + + def _create_request_with_data(self, user, course_id=None, method='POST'): + """Helper to create a request with data.""" + if method == 'POST': + request = self.factory.post('/api/discussion/v1/moderation/bulk-delete-ban/') + else: + request = self.factory.get('/api/discussion/v1/moderation/banned-users/') + + request.user = user + request.data = {'course_id': course_id} if course_id else {} + return request + + def test_unauthenticated_user_denied(self): + """Unauthenticated users should be denied.""" + request = self.factory.post('/api/discussion/v1/moderation/bulk-delete-ban/') + request.user = Mock(is_authenticated=False) + view = self._create_view_with_kwargs() + + self.assertFalse(self.permission.has_permission(request, view)) + + def test_global_staff_with_course_id_in_data(self): + """Global staff should have permission when course_id is in request data.""" + user = UserFactory.create(is_staff=True) + request = self._create_request_with_data(user, self.course_key) + view = self._create_view_with_kwargs() + + self.assertTrue(self.permission.has_permission(request, view)) + + def test_course_staff_with_course_id_in_data(self): + """Course staff should have permission when course_id is in request data.""" + user = UserFactory.create() + CourseStaffRole(self.course.id).add_users(user) + request = self._create_request_with_data(user, self.course_key) + view = self._create_view_with_kwargs() + + self.assertTrue(self.permission.has_permission(request, view)) + + def test_course_instructor_with_course_id_in_data(self): + """Course instructors should have permission when course_id is in request data.""" + user = UserFactory.create() + CourseInstructorRole(self.course.id).add_users(user) + request = self._create_request_with_data(user, self.course_key) + view = self._create_view_with_kwargs() + + self.assertTrue(self.permission.has_permission(request, view)) + + def test_forum_moderator_with_course_id_in_data(self): + """Forum moderators should have permission when course_id is in request data.""" + user = UserFactory.create() + role = Role.objects.create(name='Moderator', course_id=self.course.id) + role.users.add(user) + request = self._create_request_with_data(user, self.course_key) + view = self._create_view_with_kwargs() + + self.assertTrue(self.permission.has_permission(request, view)) + + def test_regular_student_denied(self): + """Regular students should be denied.""" + user = UserFactory.create() + request = self._create_request_with_data(user, self.course_key) + view = self._create_view_with_kwargs() + + self.assertFalse(self.permission.has_permission(request, view)) + + def test_course_id_in_url_kwargs(self): + """Permission should work when course_id is in URL kwargs.""" + user = UserFactory.create() + CourseStaffRole(self.course.id).add_users(user) + request = self.factory.get('/api/discussion/v1/moderation/banned-users/') + request.user = user + request.data = {} + request.query_params = {} + view = self._create_view_with_kwargs(self.course_key) + + self.assertTrue(self.permission.has_permission(request, view)) + + def test_no_course_id_only_global_staff_allowed(self): + """When no course_id provided, only global staff should be allowed.""" + # Global staff allowed + global_staff = UserFactory.create(is_staff=True) + request = self._create_request_with_data(global_staff) + view = self._create_view_with_kwargs() + self.assertTrue(self.permission.has_permission(request, view)) + + # Regular user denied + regular_user = UserFactory.create() + request = self._create_request_with_data(regular_user) + view = self._create_view_with_kwargs() + self.assertFalse(self.permission.has_permission(request, view)) + + def test_staff_different_course_denied(self): + """Staff from different course should be denied.""" + other_course = CourseFactory.create(org='OtherX', number='CS201', run='2024') + user = UserFactory.create() + CourseStaffRole(other_course.id).add_users(user) + request = self._create_request_with_data(user, self.course_key) + view = self._create_view_with_kwargs() + + self.assertFalse(self.permission.has_permission(request, view)) diff --git a/lms/djangoapps/discussion/rest_api/tests/test_permissions.py b/lms/djangoapps/discussion/rest_api/tests/test_permissions.py index 405726e2125b..e0a325a3fa3d 100644 --- a/lms/djangoapps/discussion/rest_api/tests/test_permissions.py +++ b/lms/djangoapps/discussion/rest_api/tests/test_permissions.py @@ -4,12 +4,16 @@ import itertools +from unittest.mock import Mock import ddt from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase from xmodule.modulestore.tests.factories import CourseFactory +from common.djangoapps.student.roles import CourseInstructorRole, CourseStaffRole +from common.djangoapps.student.tests.factories import UserFactory from lms.djangoapps.discussion.rest_api.permissions import ( + IsAllowedToRestore, can_delete, get_editable_fields, get_initializable_comment_fields, @@ -18,6 +22,12 @@ from openedx.core.djangoapps.django_comment_common.comment_client.comment import Comment from openedx.core.djangoapps.django_comment_common.comment_client.thread import Thread from openedx.core.djangoapps.django_comment_common.comment_client.user import User +from openedx.core.djangoapps.django_comment_common.models import ( + FORUM_ROLE_ADMINISTRATOR, + FORUM_ROLE_COMMUNITY_TA, + FORUM_ROLE_MODERATOR, + Role, +) def _get_context( @@ -202,3 +212,102 @@ def test_comment(self, is_author, is_thread_author, is_privileged): thread=Thread(user_id="5" if is_thread_author else "6") ) assert can_delete(comment, context) == (is_author or is_privileged) + + +@ddt.ddt +class IsAllowedToRestoreTest(ModuleStoreTestCase): + """Tests for IsAllowedToRestore permission class""" + + def setUp(self): + super().setUp() + self.course = CourseFactory.create() + self.permission = IsAllowedToRestore() + + def _create_mock_request(self, user, course_id): + """Helper to create a mock request object""" + request = Mock() + request.user = user + request.data = {"course_id": str(course_id)} + return request + + def _create_mock_view(self): + """Helper to create a mock view object""" + return Mock() + + def test_unauthenticated_user_denied(self): + """Test that unauthenticated users are denied""" + user = Mock() + user.is_authenticated = False + request = self._create_mock_request(user, self.course.id) + view = self._create_mock_view() + + assert not self.permission.has_permission(request, view) + + def test_missing_course_id_denied(self): + """Test that requests without course_id are denied""" + user = UserFactory.create() + request = Mock() + request.user = user + request.data = {} # No course_id + view = self._create_mock_view() + + assert not self.permission.has_permission(request, view) + + def test_invalid_course_id_denied(self): + """Test that requests with invalid course_id are denied""" + user = UserFactory.create() + request = Mock() + request.user = user + request.data = {"course_id": "invalid-course-id"} + view = self._create_mock_view() + + assert not self.permission.has_permission(request, view) + + def test_global_staff_allowed(self): + """Test that global staff users are allowed""" + user = UserFactory.create(is_staff=True) + request = self._create_mock_request(user, self.course.id) + view = self._create_mock_view() + + assert self.permission.has_permission(request, view) + + def test_course_staff_allowed(self): + """Test that course staff are allowed""" + user = UserFactory.create() + CourseStaffRole(self.course.id).add_users(user) + request = self._create_mock_request(user, self.course.id) + view = self._create_mock_view() + + assert self.permission.has_permission(request, view) + + def test_course_instructor_allowed(self): + """Test that course instructors are allowed""" + user = UserFactory.create() + CourseInstructorRole(self.course.id).add_users(user) + request = self._create_mock_request(user, self.course.id) + view = self._create_mock_view() + + assert self.permission.has_permission(request, view) + + @ddt.data( + FORUM_ROLE_ADMINISTRATOR, + FORUM_ROLE_MODERATOR, + FORUM_ROLE_COMMUNITY_TA, + ) + def test_discussion_privileged_users_allowed(self, role_name): + """Test that discussion privileged users (moderator, community TA, administrator) are allowed""" + user = UserFactory.create() + role = Role.objects.get_or_create(name=role_name, course_id=self.course.id)[0] + role.users.add(user) + request = self._create_mock_request(user, self.course.id) + view = self._create_mock_view() + + assert self.permission.has_permission(request, view) + + def test_regular_user_denied(self): + """Test that regular users without privileges are denied""" + user = UserFactory.create() + request = self._create_mock_request(user, self.course.id) + view = self._create_mock_view() + + assert not self.permission.has_permission(request, view) diff --git a/lms/djangoapps/discussion/rest_api/tests/test_serializers.py b/lms/djangoapps/discussion/rest_api/tests/test_serializers.py index 0cbcc0bebdd1..812bf7a6b9b3 100644 --- a/lms/djangoapps/discussion/rest_api/tests/test_serializers.py +++ b/lms/djangoapps/discussion/rest_api/tests/test_serializers.py @@ -9,19 +9,17 @@ import httpretty from django.test.client import RequestFactory from django.test.utils import override_settings -from xmodule.modulestore import ModuleStoreEnum -from xmodule.modulestore.django import modulestore -from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase -from xmodule.modulestore.tests.factories import CourseFactory from common.djangoapps.student.tests.factories import UserFactory from common.djangoapps.util.testing import UrlResetMixin -from lms.djangoapps.discussion.django_comment_client.tests.utils import ForumsEnableMixin +from lms.djangoapps.discussion.django_comment_client.tests.utils import ( + ForumsEnableMixin, +) from lms.djangoapps.discussion.rest_api.serializers import ( CommentSerializer, ThreadSerializer, filter_spam_urls_from_html, - get_context + get_context, ) from lms.djangoapps.discussion.rest_api.tests.utils import ( CommentsServiceMockMixin, @@ -39,6 +37,10 @@ FORUM_ROLE_STUDENT, Role, ) +from xmodule.modulestore import ModuleStoreEnum +from xmodule.modulestore.django import modulestore +from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase +from xmodule.modulestore.tests.factories import CourseFactory @ddt.ddt @@ -46,13 +48,18 @@ class SerializerTestMixin(ForumsEnableMixin, CommentsServiceMockMixin, UrlResetM """ Test Mixin for Serializer tests """ + @classmethod - @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) + @mock.patch.dict( + "django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True} + ) def setUpClass(cls): super().setUpClass() cls.course = CourseFactory.create() - @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) + @mock.patch.dict( + "django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True} + ) def setUp(self): super().setUp() httpretty.reset() @@ -60,8 +67,8 @@ def setUp(self): self.addCleanup(httpretty.reset) self.addCleanup(httpretty.disable) patcher = mock.patch( - 'openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled', - return_value=False + "openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled", + return_value=False, ) patcher.start() self.addCleanup(patcher.stop) @@ -89,7 +96,9 @@ def create_role(self, role_name, users, course=None): (FORUM_ROLE_STUDENT, False, True, True), ) @ddt.unpack - def test_anonymity(self, role_name, anonymous, anonymous_to_peers, expected_serialized_anonymous): + def test_anonymity( + self, role_name, anonymous, anonymous_to_peers, expected_serialized_anonymous + ): """ Test that content is properly made anonymous. @@ -107,7 +116,9 @@ def test_anonymity(self, role_name, anonymous, anonymous_to_peers, expected_seri """ self.create_role(role_name, [self.user]) serialized = self.serialize( - self.make_cs_content({"anonymous": anonymous, "anonymous_to_peers": anonymous_to_peers}) + self.make_cs_content( + {"anonymous": anonymous, "anonymous_to_peers": anonymous_to_peers} + ) ) actual_serialized_anonymous = serialized["author"] is None assert actual_serialized_anonymous == expected_serialized_anonymous @@ -138,17 +149,19 @@ def test_author_labels(self, role_name, anonymous, expected_label): """ self.create_role(role_name, [self.author]) serialized = self.serialize(self.make_cs_content({"anonymous": anonymous})) - assert serialized['author_label'] == expected_label + assert serialized["author_label"] == expected_label def test_abuse_flagged(self): - serialized = self.serialize(self.make_cs_content({"abuse_flaggers": [str(self.user.id)]})) - assert serialized['abuse_flagged'] is True + serialized = self.serialize( + self.make_cs_content({"abuse_flaggers": [str(self.user.id)]}) + ) + assert serialized["abuse_flagged"] is True def test_voted(self): thread_id = "test_thread" self.register_get_user_response(self.user, upvoted_ids=[thread_id]) serialized = self.serialize(self.make_cs_content({"id": thread_id})) - assert serialized['voted'] is True + assert serialized["voted"] is True @ddt.ddt @@ -175,47 +188,62 @@ def serialize(self, thread): Create a serializer with an appropriate context and use it to serialize the given thread, returning the result. """ - return ThreadSerializer(thread, context=get_context(self.course, self.request)).data + return ThreadSerializer( + thread, context=get_context(self.course, self.request) + ).data def test_basic(self): - thread = make_minimal_cs_thread({ - "id": "test_thread", - "course_id": str(self.course.id), - "commentable_id": "test_topic", - "user_id": str(self.author.id), - "username": self.author.username, - "title": "Test Title", - "body": "Test body", - "pinned": True, - "votes": {"up_count": 4}, - "comments_count": 5, - "unread_comments_count": 3, - }) - expected = self.expected_thread_data({ - "author": self.author.username, - "can_delete": False, - "vote_count": 4, - "comment_count": 6, - "unread_comment_count": 3, - "pinned": True, - "editable_fields": ["abuse_flagged", "copy_link", "following", "read", "voted"], - "abuse_flagged_count": None, - "edit_by_label": None, - "closed_by_label": None, - }) + thread = make_minimal_cs_thread( + { + "id": "test_thread", + "course_id": str(self.course.id), + "commentable_id": "test_topic", + "user_id": str(self.author.id), + "username": self.author.username, + "title": "Test Title", + "body": "Test body", + "pinned": True, + "votes": {"up_count": 4}, + "comments_count": 5, + "unread_comments_count": 3, + } + ) + expected = self.expected_thread_data( + { + "author": self.author.username, + "can_delete": False, + "vote_count": 4, + "comment_count": 6, + "unread_comment_count": 3, + "pinned": True, + "editable_fields": [ + "abuse_flagged", + "copy_link", + "following", + "read", + "voted", + ], + "abuse_flagged_count": None, + "edit_by_label": None, + "closed_by_label": None, + "is_deleted": None, + } + ) assert self.serialize(thread) == expected thread["thread_type"] = "question" - expected.update({ - "type": "question", - "comment_list_url": None, - "endorsed_comment_list_url": ( - "http://testserver/api/discussion/v1/comments/?thread_id=test_thread&endorsed=True" - ), - "non_endorsed_comment_list_url": ( - "http://testserver/api/discussion/v1/comments/?thread_id=test_thread&endorsed=False" - ), - }) + expected.update( + { + "type": "question", + "comment_list_url": None, + "endorsed_comment_list_url": ( + "http://testserver/api/discussion/v1/comments/?thread_id=test_thread&endorsed=True" + ), + "non_endorsed_comment_list_url": ( + "http://testserver/api/discussion/v1/comments/?thread_id=test_thread&endorsed=False" + ), + } + ) assert self.serialize(thread) == expected def test_pinned_missing(self): @@ -227,34 +255,34 @@ def test_pinned_missing(self): del thread_data["pinned"] self.register_get_thread_response(thread_data) serialized = self.serialize(thread_data) - assert serialized['pinned'] is False + assert serialized["pinned"] is False def test_group(self): self.course.cohort_config = {"cohorted": True} modulestore().update_item(self.course, ModuleStoreEnum.UserID.test) cohort = CohortFactory.create(course_id=self.course.id) serialized = self.serialize(self.make_cs_content({"group_id": cohort.id})) - assert serialized['group_id'] == cohort.id - assert serialized['group_name'] == cohort.name + assert serialized["group_id"] == cohort.id + assert serialized["group_name"] == cohort.name def test_following(self): thread_id = "test_thread" self.register_get_user_response(self.user, subscribed_thread_ids=[thread_id]) serialized = self.serialize(self.make_cs_content({"id": thread_id})) - assert serialized['following'] is True + assert serialized["following"] is True def test_response_count(self): thread_data = self.make_cs_content({"resp_total": 2}) self.register_get_thread_response(thread_data) serialized = self.serialize(thread_data) - assert serialized['response_count'] == 2 + assert serialized["response_count"] == 2 def test_response_count_missing(self): thread_data = self.make_cs_content({}) del thread_data["resp_total"] self.register_get_thread_response(thread_data) serialized = self.serialize(thread_data) - assert 'response_count' not in serialized + assert "response_count" not in serialized @ddt.data( (FORUM_ROLE_MODERATOR, True), @@ -272,43 +300,62 @@ def test_closed_by_label_field(self, role, visible): self.create_role(FORUM_ROLE_MODERATOR, [moderator]) self.create_role(request_role, [self.user]) - thread = make_minimal_cs_thread({ - "id": "test_thread", - "course_id": str(self.course.id), - "commentable_id": "test_topic", - "user_id": str(author.id), - "username": author.username, - "title": "Test Title", - "body": "Test body", - "pinned": True, - "votes": {"up_count": 4}, - "comments_count": 5, - "unread_comments_count": 3, - "closed_by": moderator - }) + thread = make_minimal_cs_thread( + { + "id": "test_thread", + "course_id": str(self.course.id), + "commentable_id": "test_topic", + "user_id": str(author.id), + "username": author.username, + "title": "Test Title", + "body": "Test body", + "pinned": True, + "votes": {"up_count": 4}, + "comments_count": 5, + "unread_comments_count": 3, + "closed_by": moderator, + } + ) closed_by_label = "Moderator" if visible else None closed_by = moderator if visible else None can_delete = role != FORUM_ROLE_STUDENT editable_fields = ["abuse_flagged", "copy_link", "following", "read", "voted"] if role == "author": editable_fields.remove("voted") - editable_fields.extend(['anonymous', 'raw_body', 'title', 'topic_id', 'type']) + editable_fields.extend( + ["anonymous", "raw_body", "title", "topic_id", "type"] + ) elif role == FORUM_ROLE_MODERATOR: - editable_fields.extend(['close_reason_code', 'closed', 'edit_reason_code', 'pinned', - 'raw_body', 'title', 'topic_id', 'type']) - expected = self.expected_thread_data({ - "author": author.username, - "can_delete": can_delete, - "vote_count": 4, - "comment_count": 6, - "unread_comment_count": 3, - "pinned": True, - "editable_fields": sorted(editable_fields), - "abuse_flagged_count": None, - "edit_by_label": None, - "closed_by_label": closed_by_label, - "closed_by": closed_by, - }) + editable_fields.extend( + [ + "close_reason_code", + "closed", + "edit_reason_code", + "pinned", + "raw_body", + "title", + "topic_id", + "type", + ] + ) + # is_deleted is visible (False) for privileged users and authors, hidden (None) for others + is_deleted = False if role in (FORUM_ROLE_MODERATOR, "author") else None + expected = self.expected_thread_data( + { + "author": author.username, + "can_delete": can_delete, + "vote_count": 4, + "comment_count": 6, + "unread_comment_count": 3, + "pinned": True, + "editable_fields": sorted(editable_fields), + "abuse_flagged_count": None, + "edit_by_label": None, + "closed_by_label": closed_by_label, + "closed_by": closed_by, + "is_deleted": is_deleted, + } + ) assert self.serialize(thread) == expected @ddt.data( @@ -327,48 +374,69 @@ def test_edit_by_label_field(self, role, visible): self.create_role(FORUM_ROLE_MODERATOR, [moderator]) self.create_role(request_role, [self.user]) - thread = make_minimal_cs_thread({ - "id": "test_thread", - "course_id": str(self.course.id), - "commentable_id": "test_topic", - "user_id": str(author.id), - "username": author.username, - "title": "Test Title", - "body": "Test body", - "pinned": True, - "votes": {"up_count": 4}, - "edit_history": [{"editor_username": moderator}], - "comments_count": 5, - "unread_comments_count": 3, - "closed_by": None - }) + thread = make_minimal_cs_thread( + { + "id": "test_thread", + "course_id": str(self.course.id), + "commentable_id": "test_topic", + "user_id": str(author.id), + "username": author.username, + "title": "Test Title", + "body": "Test body", + "pinned": True, + "votes": {"up_count": 4}, + "edit_history": [{"editor_username": moderator}], + "comments_count": 5, + "unread_comments_count": 3, + "closed_by": None, + } + ) edit_by_label = "Moderator" if visible else None can_delete = role != FORUM_ROLE_STUDENT - last_edit = None if role == FORUM_ROLE_STUDENT else {"editor_username": moderator} + last_edit = ( + None if role == FORUM_ROLE_STUDENT else {"editor_username": moderator} + ) editable_fields = ["abuse_flagged", "copy_link", "following", "read", "voted"] if role == "author": editable_fields.remove("voted") - editable_fields.extend(['anonymous', 'raw_body', 'title', 'topic_id', 'type']) + editable_fields.extend( + ["anonymous", "raw_body", "title", "topic_id", "type"] + ) elif role == FORUM_ROLE_MODERATOR: - editable_fields.extend(['close_reason_code', 'closed', 'edit_reason_code', 'pinned', - 'raw_body', 'title', 'topic_id', 'type']) + editable_fields.extend( + [ + "close_reason_code", + "closed", + "edit_reason_code", + "pinned", + "raw_body", + "title", + "topic_id", + "type", + ] + ) - expected = self.expected_thread_data({ - "author": author.username, - "can_delete": can_delete, - "vote_count": 4, - "comment_count": 6, - "unread_comment_count": 3, - "pinned": True, - "editable_fields": sorted(editable_fields), - "abuse_flagged_count": None, - "last_edit": last_edit, - "edit_by_label": edit_by_label, - "closed_by_label": None, - "closed_by": None, - }) + # is_deleted is visible (False) for privileged users and authors, hidden (None) for others + is_deleted = False if role in (FORUM_ROLE_MODERATOR, "author") else None + expected = self.expected_thread_data( + { + "author": author.username, + "can_delete": can_delete, + "vote_count": 4, + "comment_count": 6, + "unread_comment_count": 3, + "pinned": True, + "editable_fields": sorted(editable_fields), + "abuse_flagged_count": None, + "last_edit": last_edit, + "edit_by_label": edit_by_label, + "closed_by_label": None, + "closed_by": None, + "is_deleted": is_deleted, + } + ) assert self.serialize(thread) == expected def test_get_preview_body(self): @@ -384,7 +452,10 @@ def test_get_preview_body(self): {"body": "

This is a test thread body with some text.

"} ) serialized = self.serialize(thread_data) - assert serialized['preview_body'] == "This is a test thread body with some text." + assert ( + serialized["preview_body"] + == "This is a test thread body with some text." + ) @ddt.ddt @@ -402,12 +473,12 @@ def make_cs_content(self, overrides=None, with_endorsement=False): """ merged_overrides = { "user_id": str(self.author.id), - "username": self.author.username + "username": self.author.username, } if with_endorsement: merged_overrides["endorsement"] = { "user_id": str(self.endorser.id), - "time": self.endorsed_at + "time": self.endorsed_at, } merged_overrides.update(overrides or {}) return make_minimal_cs_comment(merged_overrides) @@ -417,7 +488,9 @@ def serialize(self, comment, thread_data=None): Create a serializer with an appropriate context and use it to serialize the given comment, returning the result. """ - context = get_context(self.course, self.request, make_minimal_cs_thread(thread_data)) + context = get_context( + self.course, self.request, make_minimal_cs_thread(thread_data) + ) return CommentSerializer(comment, context=context).data def test_basic(self): @@ -446,6 +519,8 @@ def test_basic(self): "parent_id": None, "author": self.author.username, "author_label": None, + "is_author_banned": False, + "author_ban_scope": None, "created_at": "2015-04-28T00:00:00Z", "updated_at": "2015-04-28T11:11:11Z", "raw_body": "Test body", @@ -464,6 +539,7 @@ def test_basic(self): "can_delete": False, "last_edit": None, "edit_by_label": None, + "learner_status": "new", "profile_image": { "has_image": False, "image_url_full": "http://testserver/static/default_500.png", @@ -471,6 +547,10 @@ def test_basic(self): "image_url_medium": "http://testserver/static/default_50.png", "image_url_small": "http://testserver/static/default_30.png", }, + "is_deleted": None, + "deleted_at": None, + "deleted_by": None, + "deleted_by_label": None, } assert self.serialize(comment) == expected @@ -483,7 +563,7 @@ def test_basic(self): FORUM_ROLE_COMMUNITY_TA, FORUM_ROLE_STUDENT, ], - [True, False] + [True, False], ) ) @ddt.unpack @@ -500,10 +580,12 @@ def test_endorsed_by(self, endorser_role_name, thread_anonymous): self.create_role(endorser_role_name, [self.endorser]) serialized = self.serialize( self.make_cs_content(with_endorsement=True), - thread_data={"anonymous": thread_anonymous} + thread_data={"anonymous": thread_anonymous}, ) actual_endorser_anonymous = serialized["endorsed_by"] is None - expected_endorser_anonymous = endorser_role_name == FORUM_ROLE_STUDENT and thread_anonymous + expected_endorser_anonymous = ( + endorser_role_name == FORUM_ROLE_STUDENT and thread_anonymous + ) assert actual_endorser_anonymous == expected_endorser_anonymous @ddt.data( @@ -526,56 +608,106 @@ def test_endorsed_by_labels(self, role_name, expected_label): """ self.create_role(role_name, [self.endorser]) serialized = self.serialize(self.make_cs_content(with_endorsement=True)) - assert serialized['endorsed_by_label'] == expected_label + assert serialized["endorsed_by_label"] == expected_label def test_endorsed_at(self): serialized = self.serialize(self.make_cs_content(with_endorsement=True)) - assert serialized['endorsed_at'] == self.endorsed_at + assert serialized["endorsed_at"] == self.endorsed_at def test_children(self): - comment = self.make_cs_content({ - "id": "test_root", - "children": [ - self.make_cs_content({ - "id": "test_child_1", - "parent_id": "test_root", - }), - self.make_cs_content({ - "id": "test_child_2", - "parent_id": "test_root", - "children": [ - self.make_cs_content({ - "id": "test_grandchild", - "parent_id": "test_child_2" - }) - ], - }), - ], - }) + comment = self.make_cs_content( + { + "id": "test_root", + "children": [ + self.make_cs_content( + { + "id": "test_child_1", + "parent_id": "test_root", + } + ), + self.make_cs_content( + { + "id": "test_child_2", + "parent_id": "test_root", + "children": [ + self.make_cs_content( + { + "id": "test_grandchild", + "parent_id": "test_child_2", + } + ) + ], + } + ), + ], + } + ) serialized = self.serialize(comment) - assert serialized['children'][0]['id'] == 'test_child_1' - assert serialized['children'][0]['parent_id'] == 'test_root' - assert serialized['children'][1]['id'] == 'test_child_2' - assert serialized['children'][1]['parent_id'] == 'test_root' - assert serialized['children'][1]['children'][0]['id'] == 'test_grandchild' - assert serialized['children'][1]['children'][0]['parent_id'] == 'test_child_2' + assert serialized["children"][0]["id"] == "test_child_1" + assert serialized["children"][0]["parent_id"] == "test_root" + assert serialized["children"][1]["id"] == "test_child_2" + assert serialized["children"][1]["parent_id"] == "test_root" + assert serialized["children"][1]["children"][0]["id"] == "test_grandchild" + assert serialized["children"][1]["children"][0]["parent_id"] == "test_child_2" + + @mock.patch.dict( + "django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True} + ) + def test_ban_lookup_is_memoized_for_duplicate_authors(self): + comment_1 = self.make_cs_content({"id": "test_comment_1"}) + comment_2 = self.make_cs_content({"id": "test_comment_2"}) + context = get_context( + self.course, + self.request, + make_minimal_cs_thread({"id": "test_thread"}), + ) + + with mock.patch( + "lms.djangoapps.discussion.toggles.ENABLE_DISCUSSION_BAN.is_enabled", + return_value=True, + ), mock.patch( + "forum.api.is_user_banned", + return_value=True, + create=True, + ) as is_user_banned_mock, mock.patch( + "forum.api.get_user_bans", + return_value=[{"is_active": True, "scope": "course"}], + create=True, + ) as get_user_bans_mock: + serialized = CommentSerializer( + [comment_1, comment_2], + context=context, + many=True, + ).data + + assert serialized[0]["is_author_banned"] is True + assert serialized[0]["author_ban_scope"] == "course" + assert serialized[1]["is_author_banned"] is True + assert serialized[1]["author_ban_scope"] == "course" + assert is_user_banned_mock.call_count == 1 + assert get_user_bans_mock.call_count == 1 @ddt.ddt class ThreadSerializerDeserializationTest( - ForumsEnableMixin, - CommentsServiceMockMixin, - UrlResetMixin, - SharedModuleStoreTestCase + ForumsEnableMixin, + CommentsServiceMockMixin, + UrlResetMixin, + SharedModuleStoreTestCase, ): """Tests for ThreadSerializer deserialization.""" + @classmethod - @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) + @mock.patch.dict( + "django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True} + ) def setUpClass(cls): super().setUpClass() cls.course = CourseFactory.create() - @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) + @mock.patch.dict( + "django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True} + ) def setUp(self): super().setUp() httpretty.reset() @@ -583,8 +715,8 @@ def setUp(self): self.addCleanup(httpretty.reset) self.addCleanup(httpretty.disable) patcher = mock.patch( - 'openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled', - return_value=False + "openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled", + return_value=False, ) patcher.start() self.addCleanup(patcher.stop) @@ -599,18 +731,22 @@ def setUp(self): "title": "Test Title", "raw_body": "Test body", } - self.existing_thread = Thread(**make_minimal_cs_thread({ - "id": "existing_thread", - "course_id": str(self.course.id), - "commentable_id": "original_topic", - "thread_type": "discussion", - "title": "Original Title", - "body": "Original body", - "user_id": str(self.user.id), - "username": self.user.username, - "read": "False", - "endorsed": "False" - })) + self.existing_thread = Thread( + **make_minimal_cs_thread( + { + "id": "existing_thread", + "course_id": str(self.course.id), + "commentable_id": "original_topic", + "thread_type": "discussion", + "title": "Original Title", + "body": "Original body", + "user_id": str(self.user.id), + "username": self.user.username, + "read": "False", + "endorsed": "False", + } + ) + ) def save_and_reserialize(self, data, instance=None): """ @@ -622,7 +758,7 @@ def save_and_reserialize(self, data, instance=None): instance, data=data, partial=(instance is not None), - context=get_context(self.course, self.request) + context=get_context(self.course, self.request), ) assert serializer.is_valid() serializer.save() @@ -634,33 +770,36 @@ def test_create_missing_field(self): data.pop(field) serializer = ThreadSerializer(data=data) assert not serializer.is_valid() - assert serializer.errors == {field: ['This field is required.']} + assert serializer.errors == {field: ["This field is required."]} @ddt.data("", " ") def test_create_empty_string(self, value): data = self.minimal_data.copy() data.update({field: value for field in ["topic_id", "title", "raw_body"]}) - serializer = ThreadSerializer(data=data, context=get_context(self.course, self.request)) + serializer = ThreadSerializer( + data=data, context=get_context(self.course, self.request) + ) assert not serializer.is_valid() assert serializer.errors == { - field: ['This field may not be blank.'] for field in ['topic_id', 'title', 'raw_body'] + field: ["This field may not be blank."] + for field in ["topic_id", "title", "raw_body"] } def test_update_empty(self): self.register_put_thread_response(self.existing_thread.attributes) self.save_and_reserialize({}, self.existing_thread) assert parsed_body(httpretty.last_request()) == { - 'course_id': [str(self.course.id)], - 'commentable_id': ['original_topic'], - 'thread_type': ['discussion'], - 'title': ['Original Title'], - 'body': ['Original body'], - 'anonymous': ['False'], - 'anonymous_to_peers': ['False'], - 'closed': ['False'], - 'pinned': ['False'], - 'user_id': [str(self.user.id)], - 'read': ['False'] + "course_id": [str(self.course.id)], + "commentable_id": ["original_topic"], + "thread_type": ["discussion"], + "title": ["Original Title"], + "body": ["Original body"], + "anonymous": ["False"], + "anonymous_to_peers": ["False"], + "closed": ["False"], + "pinned": ["False"], + "user_id": [str(self.user.id)], + "read": ["False"], } @ddt.data(True, False) @@ -675,18 +814,18 @@ def test_update_all(self, read): } saved = self.save_and_reserialize(data, self.existing_thread) assert parsed_body(httpretty.last_request()) == { - 'course_id': [str(self.course.id)], - 'commentable_id': ['edited_topic'], - 'thread_type': ['question'], - 'title': ['Edited Title'], - 'body': ['Edited body'], - 'anonymous': ['False'], - 'anonymous_to_peers': ['False'], - 'closed': ['False'], - 'pinned': ['False'], - 'user_id': [str(self.user.id)], - 'read': [str(read)], - 'editing_user_id': [str(self.user.id)], + "course_id": [str(self.course.id)], + "commentable_id": ["edited_topic"], + "thread_type": ["question"], + "title": ["Edited Title"], + "body": ["Edited body"], + "anonymous": ["False"], + "anonymous_to_peers": ["False"], + "closed": ["False"], + "pinned": ["False"], + "user_id": [str(self.user.id)], + "read": [str(read)], + "editing_user_id": [str(self.user.id)], } for key in data: assert saved[key] == data[key] @@ -701,7 +840,7 @@ def test_update_anonymous(self): "anonymous": True, } self.save_and_reserialize(data, self.existing_thread) - assert parsed_body(httpretty.last_request())["anonymous"] == ['True'] + assert parsed_body(httpretty.last_request())["anonymous"] == ["True"] def test_update_anonymous_to_peers(self): """ @@ -713,7 +852,7 @@ def test_update_anonymous_to_peers(self): "anonymous_to_peers": True, } self.save_and_reserialize(data, self.existing_thread) - assert parsed_body(httpretty.last_request())["anonymous_to_peers"] == ['True'] + assert parsed_body(httpretty.last_request())["anonymous_to_peers"] == ["True"] @ddt.data("", " ") def test_update_empty_string(self, value): @@ -721,11 +860,12 @@ def test_update_empty_string(self, value): self.existing_thread, data={field: value for field in ["topic_id", "title", "raw_body"]}, partial=True, - context=get_context(self.course, self.request) + context=get_context(self.course, self.request), ) assert not serializer.is_valid() assert serializer.errors == { - field: ['This field may not be blank.'] for field in ['topic_id', 'title', 'raw_body'] + field: ["This field may not be blank."] + for field in ["topic_id", "title", "raw_body"] } def test_update_course_id(self): @@ -733,15 +873,20 @@ def test_update_course_id(self): self.existing_thread, data={"course_id": "some/other/course"}, partial=True, - context=get_context(self.course, self.request) + context=get_context(self.course, self.request), ) assert not serializer.is_valid() - assert serializer.errors == {'course_id': ['This field is not allowed in an update.']} + assert serializer.errors == { + "course_id": ["This field is not allowed in an update."] + } @ddt.ddt -class CommentSerializerDeserializationTest(ForumsEnableMixin, CommentsServiceMockMixin, SharedModuleStoreTestCase): +class CommentSerializerDeserializationTest( + ForumsEnableMixin, CommentsServiceMockMixin, SharedModuleStoreTestCase +): """Tests for ThreadSerializer deserialization.""" + @classmethod def setUpClass(cls): super().setUpClass() @@ -754,8 +899,8 @@ def setUp(self): self.addCleanup(httpretty.reset) self.addCleanup(httpretty.disable) patcher = mock.patch( - 'openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled', - return_value=False + "openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled", + return_value=False, ) patcher.start() self.addCleanup(patcher.stop) @@ -777,14 +922,18 @@ def setUp(self): "thread_id": "test_thread", "raw_body": "Test body", } - self.existing_comment = Comment(**make_minimal_cs_comment({ - "id": "existing_comment", - "thread_id": "dummy", - "body": "Original body", - "user_id": str(self.user.id), - "username": self.user.username, - "course_id": str(self.course.id), - })) + self.existing_comment = Comment( + **make_minimal_cs_comment( + { + "id": "existing_comment", + "thread_id": "dummy", + "body": "Original body", + "user_id": str(self.user.id), + "username": self.user.username, + "course_id": str(self.course.id), + } + ) + ) def save_and_reserialize(self, data, instance=None): """ @@ -794,13 +943,10 @@ def save_and_reserialize(self, data, instance=None): context = get_context( self.course, self.request, - make_minimal_cs_thread({"course_id": str(self.course.id)}) + make_minimal_cs_thread({"course_id": str(self.course.id)}), ) serializer = CommentSerializer( - instance, - data=data, - partial=(instance is not None), - context=context + instance, data=data, partial=(instance is not None), context=context ) assert serializer.is_valid() serializer.save() @@ -812,21 +958,23 @@ def test_create_missing_field(self): data.pop(field) serializer = CommentSerializer( data=data, - context=get_context(self.course, self.request, make_minimal_cs_thread()) + context=get_context( + self.course, self.request, make_minimal_cs_thread() + ), ) assert not serializer.is_valid() - assert serializer.errors == {field: ['This field is required.']} + assert serializer.errors == {field: ["This field is required."]} def test_update_empty(self): self.register_put_comment_response(self.existing_comment.attributes) self.save_and_reserialize({}, instance=self.existing_comment) assert parsed_body(httpretty.last_request()) == { - 'body': ['Original body'], - 'course_id': [str(self.course.id)], - 'user_id': [str(self.user.id)], - 'anonymous': ['False'], - 'anonymous_to_peers': ['False'], - 'endorsed': ['False'] + "body": ["Original body"], + "course_id": [str(self.course.id)], + "user_id": [str(self.user.id)], + "anonymous": ["False"], + "anonymous_to_peers": ["False"], + "endorsed": ["False"], } def test_update_anonymous(self): @@ -839,7 +987,7 @@ def test_update_anonymous(self): "anonymous": True, } self.save_and_reserialize(data, self.existing_comment) - assert parsed_body(httpretty.last_request())["anonymous"] == ['True'] + assert parsed_body(httpretty.last_request())["anonymous"] == ["True"] def test_update_anonymous_to_peers(self): """ @@ -851,7 +999,7 @@ def test_update_anonymous_to_peers(self): "anonymous_to_peers": True, } self.save_and_reserialize(data, self.existing_comment) - assert parsed_body(httpretty.last_request())["anonymous_to_peers"] == ['True'] + assert parsed_body(httpretty.last_request())["anonymous_to_peers"] == ["True"] @ddt.data("thread_id", "parent_id") def test_update_non_updatable(self, field): @@ -859,23 +1007,26 @@ def test_update_non_updatable(self, field): self.existing_comment, data={field: "different_value"}, partial=True, - context=get_context(self.course, self.request) + context=get_context(self.course, self.request), ) assert not serializer.is_valid() - assert serializer.errors == {field: ['This field is not allowed in an update.']} + assert serializer.errors == {field: ["This field is not allowed in an update."]} class FilterSpamTest(SharedModuleStoreTestCase): """ Tests for the filter_spam method """ - @override_settings(DISCUSSION_SPAM_URLS=['example.com']) + + @override_settings(DISCUSSION_SPAM_URLS=["example.com"]) def test_filter(self): self.assertEqual( - filter_spam_urls_from_html('')[0], - '
abc
' + filter_spam_urls_from_html( + '' + )[0], + "
abc
", ) self.assertEqual( - filter_spam_urls_from_html('
example.com/abc/def
')[0], - '
' + filter_spam_urls_from_html("
example.com/abc/def
")[0], + "
", ) diff --git a/lms/djangoapps/discussion/rest_api/tests/test_tasks_v2.py b/lms/djangoapps/discussion/rest_api/tests/test_tasks_v2.py index 153ba156049d..e14da2dfa851 100644 --- a/lms/djangoapps/discussion/rest_api/tests/test_tasks_v2.py +++ b/lms/djangoapps/discussion/rest_api/tests/test_tasks_v2.py @@ -14,6 +14,7 @@ from common.djangoapps.student.tests.factories import StaffFactory, UserFactory from lms.djangoapps.discussion.django_comment_client.tests.factories import RoleFactory from lms.djangoapps.discussion.rest_api.tasks import ( + delete_course_post_for_user, send_response_endorsed_notifications, send_response_notifications, send_thread_created_notification @@ -802,3 +803,50 @@ def test_response_endorsed_notifications(self): self.assertEqual(notification_data.content_url, _get_mfe_url(self.course.id, thread.id)) self.assertEqual(notification_data.app_name, 'discussion') self.assertEqual('response_endorsed', notification_data.notification_type) + + +class TestDeleteCoursePostForUserTask(ModuleStoreTestCase): + """Tests for delete_course_post_for_user task behavior.""" + + def setUp(self): + super().setUp() + self.course = CourseFactory.create() + self.target_user = UserFactory.create() + self.moderator = UserFactory.create() + + def test_ban_succeeds_when_email_send_fails(self): + """Ban should still succeed even if escalation email raises an exception.""" + with mock.patch( + 'lms.djangoapps.discussion.rest_api.tasks.Thread.delete_user_threads', + return_value=2, + ), mock.patch( + 'lms.djangoapps.discussion.rest_api.tasks.Comment.delete_user_comments', + return_value=3, + ), mock.patch( + 'forum.api.ban_user', + return_value={'id': 42}, + create=True, + ) as mock_ban_user, mock.patch( + 'lms.djangoapps.discussion.rest_api.emails.send_ban_escalation_email', + side_effect=Exception('email failure'), + ) as mock_send_email, mock.patch( + 'lms.djangoapps.discussion.rest_api.tasks.tracker.emit', + ), mock.patch( + 'lms.djangoapps.discussion.rest_api.tasks.segment.track', + ): + result = delete_course_post_for_user.run( + user_id=self.target_user.id, + username=self.target_user.username, + course_ids=[str(self.course.id)], + event_data={'triggered_by_user_id': self.moderator.id}, + ban_user=True, + ban_scope='course', + moderator_id=self.moderator.id, + reason='test reason', + ) + + mock_ban_user.assert_called_once() + mock_send_email.assert_called_once() + self.assertTrue(result['ban_created']) + self.assertEqual(result['ban_id'], 42) + self.assertIsNone(result['ban_error']) diff --git a/lms/djangoapps/discussion/rest_api/tests/test_views.py b/lms/djangoapps/discussion/rest_api/tests/test_views.py index be8a793abc92..3b4b6ff47d5e 100644 --- a/lms/djangoapps/discussion/rest_api/tests/test_views.py +++ b/lms/djangoapps/discussion/rest_api/tests/test_views.py @@ -2,7 +2,6 @@ Tests for Discussion API views """ - import json import random from datetime import datetime @@ -20,22 +19,22 @@ from rest_framework import status from rest_framework.test import APIClient, APITestCase -from lms.djangoapps.discussion.toggles import ENABLE_DISCUSSIONS_MFE -from lms.djangoapps.discussion.rest_api.utils import get_usernames_from_search_string -from xmodule.modulestore import ModuleStoreEnum -from xmodule.modulestore.django import modulestore -from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase, SharedModuleStoreTestCase -from xmodule.modulestore.tests.factories import CourseFactory, BlockFactory, check_mongo_calls - from common.djangoapps.course_modes.models import CourseMode from common.djangoapps.course_modes.tests.factories import CourseModeFactory -from common.djangoapps.student.models import get_retired_username_by_username, CourseEnrollment -from common.djangoapps.student.roles import CourseInstructorRole, CourseStaffRole, GlobalStaff +from common.djangoapps.student.models import ( + CourseEnrollment, + get_retired_username_by_username, +) +from common.djangoapps.student.roles import ( + CourseInstructorRole, + CourseStaffRole, + GlobalStaff, +) from common.djangoapps.student.tests.factories import ( AdminFactory, CourseEnrollmentFactory, SuperuserFactory, - UserFactory + UserFactory, ) from common.djangoapps.util.testing import UrlResetMixin from lms.djangoapps.discussion.django_comment_client.tests.utils import ( @@ -48,21 +47,57 @@ make_minimal_cs_comment, make_minimal_cs_thread, ) +from lms.djangoapps.discussion.rest_api.serializers import ( + BulkDeleteBanRequestSerializer, +) +from lms.djangoapps.discussion.rest_api.utils import get_usernames_from_search_string +from lms.djangoapps.discussion.rest_api.views import DiscussionModerationViewSet +from lms.djangoapps.discussion.toggles import ENABLE_DISCUSSIONS_MFE from openedx.core.djangoapps.course_groups.tests.helpers import config_course_cohorts -from openedx.core.djangoapps.discussions.config.waffle import ENABLE_NEW_STRUCTURE_DISCUSSIONS -from openedx.core.djangoapps.discussions.models import DiscussionsConfiguration, DiscussionTopicLink, Provider -from openedx.core.djangoapps.discussions.tasks import update_discussions_settings_from_course_task +from openedx.core.djangoapps.discussions.config.waffle import ( + ENABLE_NEW_STRUCTURE_DISCUSSIONS, +) +from openedx.core.djangoapps.discussions.models import ( + DiscussionsConfiguration, + DiscussionTopicLink, + Provider, +) +from openedx.core.djangoapps.discussions.tasks import ( + update_discussions_settings_from_course_task, +) from openedx.core.djangoapps.django_comment_common.models import ( CourseDiscussionSettings, Role, ) from openedx.core.djangoapps.django_comment_common.utils import seed_permissions_roles from openedx.core.djangoapps.oauth_dispatch.jwt import create_jwt_for_user -from openedx.core.djangoapps.oauth_dispatch.tests.factories import AccessTokenFactory, ApplicationFactory -from openedx.core.djangoapps.user_api.models import RetirementState, UserRetirementStatus +from openedx.core.djangoapps.oauth_dispatch.tests.factories import ( + AccessTokenFactory, + ApplicationFactory, +) +from openedx.core.djangoapps.user_api.models import ( + RetirementState, + UserRetirementStatus, +) +from openedx.core.djangoapps.django_comment_common.comment_client.utils import ( + CommentClientRequestError, +) +from xmodule.modulestore import ModuleStoreEnum +from xmodule.modulestore.django import modulestore +from xmodule.modulestore.tests.django_utils import ( + ModuleStoreTestCase, + SharedModuleStoreTestCase, +) +from xmodule.modulestore.tests.factories import ( + BlockFactory, + CourseFactory, + check_mongo_calls, +) -class DiscussionAPIViewTestMixin(ForumsEnableMixin, CommentsServiceMockMixin, UrlResetMixin): +class DiscussionAPIViewTestMixin( + ForumsEnableMixin, CommentsServiceMockMixin, UrlResetMixin +): """ Mixin for common code in tests of Discussion API views. This includes creation of common structures (e.g. a course, user, and enrollment), logging @@ -72,7 +107,9 @@ class DiscussionAPIViewTestMixin(ForumsEnableMixin, CommentsServiceMockMixin, Ur client_class = APIClient - @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) + @mock.patch.dict( + "django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True} + ) def setUp(self): super().setUp() self.maxDiff = None # pylint: disable=invalid-name @@ -81,7 +118,7 @@ def setUp(self): course="y", run="z", start=datetime.now(UTC), - discussion_topics={"Test Topic": {"id": "test_topic"}} + discussion_topics={"Test Topic": {"id": "test_topic"}}, ) self.password = "Password1234" self.user = UserFactory.create(password=self.password) @@ -96,23 +133,25 @@ def assert_response_correct(self, response, expected_status, expected_content): Assert that the response has the given status code and parsed content """ assert response.status_code == expected_status - parsed_content = json.loads(response.content.decode('utf-8')) + parsed_content = json.loads(response.content.decode("utf-8")) assert parsed_content == expected_content def register_thread(self, overrides=None): """ Create cs_thread with minimal fields and register response """ - cs_thread = make_minimal_cs_thread({ - "id": "test_thread", - "course_id": str(self.course.id), - "commentable_id": "test_topic", - "username": self.user.username, - "user_id": str(self.user.id), - "thread_type": "discussion", - "title": "Test Title", - "body": "Test body", - }) + cs_thread = make_minimal_cs_thread( + { + "id": "test_thread", + "course_id": str(self.course.id), + "commentable_id": "test_topic", + "username": self.user.username, + "user_id": str(self.user.id), + "thread_type": "discussion", + "title": "Test Title", + "body": "Test body", + } + ) cs_thread.update(overrides or {}) self.register_get_thread_response(cs_thread) self.register_put_thread_response(cs_thread) @@ -121,14 +160,16 @@ def register_comment(self, overrides=None): """ Create cs_comment with minimal fields and register response """ - cs_comment = make_minimal_cs_comment({ - "id": "test_comment", - "course_id": str(self.course.id), - "thread_id": "test_thread", - "username": self.user.username, - "user_id": str(self.user.id), - "body": "Original body", - }) + cs_comment = make_minimal_cs_comment( + { + "id": "test_comment", + "course_id": str(self.course.id), + "thread_id": "test_thread", + "username": self.user.username, + "user_id": str(self.user.id), + "body": "Original body", + } + ) cs_comment.update(overrides or {}) self.register_get_comment_response(cs_comment) self.register_put_comment_response(cs_comment) @@ -140,7 +181,7 @@ def test_not_authenticated(self): self.assert_response_correct( response, 401, - {"developer_message": "Authentication credentials were not provided."} + {"developer_message": "Authentication credentials were not provided."}, ) def test_inactive(self): @@ -149,12 +190,16 @@ def test_inactive(self): @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) -class UploadFileViewTest(ForumsEnableMixin, CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestCase): +class UploadFileViewTest( + ForumsEnableMixin, CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestCase +): """ Tests for UploadFileView. """ - @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) + @mock.patch.dict( + "django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True} + ) def setUp(self): super().setUp() self.valid_file = { @@ -165,11 +210,13 @@ def setUp(self): ), } self.user = UserFactory.create(password=self.TEST_PASSWORD) - self.course = CourseFactory.create(org='a', course='b', run='c', start=datetime.now(UTC)) + self.course = CourseFactory.create( + org="a", course="b", run="c", start=datetime.now(UTC) + ) self.url = reverse("upload_file", kwargs={"course_id": str(self.course.id)}) patcher = mock.patch( - 'openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled', - return_value=False + "openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled", + return_value=False, ) patcher.start() self.addCleanup(patcher.stop) @@ -257,10 +304,13 @@ def test_file_upload_with_thread_key(self): """ self.user_login() self.enroll_user_in_course() - response = self.client.post(self.url, { - **self.valid_file, - "thread_key": "somethread", - }) + response = self.client.post( + self.url, + { + **self.valid_file, + "thread_key": "somethread", + }, + ) response_data = json.loads(response.content) assert "/somethread/" in response_data["location"] @@ -314,7 +364,9 @@ class CommentViewSetListByUserTest( Common test cases for views retrieving user-published content. """ - @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) + @mock.patch.dict( + "django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True} + ) def setUp(self): super().setUp() @@ -323,8 +375,8 @@ def setUp(self): self.addCleanup(httpretty.reset) self.addCleanup(httpretty.disable) patcher = mock.patch( - 'openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled', - return_value=False + "openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled", + return_value=False, ) patcher.start() self.addCleanup(patcher.stop) @@ -335,7 +387,9 @@ def setUp(self): self.other_user = UserFactory.create(password=self.TEST_PASSWORD) self.register_get_user_response(self.other_user) - self.course = CourseFactory.create(org="a", course="b", run="c", start=datetime.now(UTC)) + self.course = CourseFactory.create( + org="a", course="b", run="c", start=datetime.now(UTC) + ) CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id) self.url = self.build_url(self.user.username, self.course.id) @@ -346,16 +400,18 @@ def register_mock_endpoints(self): """ self.register_get_threads_response( threads=[ - make_minimal_cs_thread({ - "id": f"test_thread_{index}", - "course_id": str(self.course.id), - "commentable_id": f"test_topic_{index}", - "username": self.user.username, - "user_id": str(self.user.id), - "thread_type": "discussion", - "title": f"Test Title #{index}", - "body": f"Test body #{index}", - }) + make_minimal_cs_thread( + { + "id": f"test_thread_{index}", + "course_id": str(self.course.id), + "commentable_id": f"test_topic_{index}", + "username": self.user.username, + "user_id": str(self.user.id), + "thread_type": "discussion", + "title": f"Test Title #{index}", + "body": f"Test body #{index}", + } + ) for index in range(30) ], page=1, @@ -363,16 +419,18 @@ def register_mock_endpoints(self): ) self.register_get_comments_response( comments=[ - make_minimal_cs_comment({ - "id": f"test_comment_{index}", - "thread_id": "test_thread", - "user_id": str(self.user.id), - "username": self.user.username, - "created_at": "2015-05-11T00:00:00Z", - "updated_at": "2015-05-11T11:11:11Z", - "body": f"Test body #{index}", - "votes": {"up_count": 4}, - }) + make_minimal_cs_comment( + { + "id": f"test_comment_{index}", + "thread_id": "test_thread", + "user_id": str(self.user.id), + "username": self.user.username, + "created_at": "2015-05-11T00:00:00Z", + "updated_at": "2015-05-11T11:11:11Z", + "body": f"Test body #{index}", + "votes": {"up_count": 4}, + } + ) for index in range(30) ], page=1, @@ -384,11 +442,13 @@ def build_url(self, username, course_id, **kwargs): Builds an URL to access content from an user on a specific course. """ base = reverse("comment-list") - query = urlencode({ - "username": username, - "course_id": str(course_id), - **kwargs, - }) + query = urlencode( + { + "username": username, + "course_id": str(course_id), + **kwargs, + } + ) return f"{base}?{query}" def assert_successful_response(self, response): @@ -414,7 +474,9 @@ def test_request_by_unauthorized_user(self): they're not either enrolled or staff members. """ self.register_mock_endpoints() - self.client.login(username=self.other_user.username, password=self.TEST_PASSWORD) + self.client.login( + username=self.other_user.username, password=self.TEST_PASSWORD + ) response = self.client.get(self.url) assert response.status_code == status.HTTP_404_NOT_FOUND assert json.loads(response.content)["developer_message"] == "Course not found." @@ -425,7 +487,9 @@ def test_request_by_enrolled_user(self): comments in that course. """ self.register_mock_endpoints() - self.client.login(username=self.other_user.username, password=self.TEST_PASSWORD) + self.client.login( + username=self.other_user.username, password=self.TEST_PASSWORD + ) CourseEnrollmentFactory.create(user=self.other_user, course_id=self.course.id) self.assert_successful_response(self.client.get(self.url)) @@ -434,7 +498,9 @@ def test_request_by_global_staff(self): Staff users are allowed to get any user's comments. """ self.register_mock_endpoints() - self.client.login(username=self.other_user.username, password=self.TEST_PASSWORD) + self.client.login( + username=self.other_user.username, password=self.TEST_PASSWORD + ) GlobalStaff().add_users(self.other_user) self.assert_successful_response(self.client.get(self.url)) @@ -445,7 +511,9 @@ def test_request_by_course_staff(self, role): course. """ self.register_mock_endpoints() - self.client.login(username=self.other_user.username, password=self.TEST_PASSWORD) + self.client.login( + username=self.other_user.username, password=self.TEST_PASSWORD + ) role(course_key=self.course.id).add_users(self.other_user) self.assert_successful_response(self.client.get(self.url)) @@ -454,7 +522,9 @@ def test_request_with_non_existent_user(self): Requests for users that don't exist result in a 404 response. """ self.register_mock_endpoints() - self.client.login(username=self.other_user.username, password=self.TEST_PASSWORD) + self.client.login( + username=self.other_user.username, password=self.TEST_PASSWORD + ) GlobalStaff().add_users(self.other_user) url = self.build_url("non_existent", self.course.id) response = self.client.get(url) @@ -465,7 +535,9 @@ def test_request_with_non_existent_course(self): Requests for courses that don't exist result in a 404 response. """ self.register_mock_endpoints() - self.client.login(username=self.other_user.username, password=self.TEST_PASSWORD) + self.client.login( + username=self.other_user.username, password=self.TEST_PASSWORD + ) GlobalStaff().add_users(self.other_user) url = self.build_url(self.user.username, "course-v1:x+y+z") response = self.client.get(url) @@ -476,14 +548,18 @@ def test_request_with_invalid_course_id(self): Requests with invalid course ID should fail form validation. """ self.register_mock_endpoints() - self.client.login(username=self.other_user.username, password=self.TEST_PASSWORD) + self.client.login( + username=self.other_user.username, password=self.TEST_PASSWORD + ) GlobalStaff().add_users(self.other_user) url = self.build_url(self.user.username, "an invalid course") response = self.client.get(url) assert response.status_code == status.HTTP_400_BAD_REQUEST parsed_response = json.loads(response.content) - assert parsed_response["field_errors"]["course_id"]["developer_message"] == \ - "'an invalid course' is not a valid course id" + assert ( + parsed_response["field_errors"]["course_id"]["developer_message"] + == "'an invalid course' is not a valid course id" + ) def test_request_with_empty_results_page(self): """ @@ -493,7 +569,9 @@ def test_request_with_empty_results_page(self): self.register_get_threads_response(threads=[], page=1, num_pages=1) self.register_get_comments_response(comments=[], page=1, num_pages=1) - self.client.login(username=self.other_user.username, password=self.TEST_PASSWORD) + self.client.login( + username=self.other_user.username, password=self.TEST_PASSWORD + ) GlobalStaff().add_users(self.other_user) url = self.build_url(self.user.username, self.course.id, page=2) response = self.client.get(url) @@ -501,17 +579,23 @@ def test_request_with_empty_results_page(self): @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) -@override_settings(DISCUSSION_MODERATION_EDIT_REASON_CODES={"test-edit-reason": "Test Edit Reason"}) -@override_settings(DISCUSSION_MODERATION_CLOSE_REASON_CODES={"test-close-reason": "Test Close Reason"}) +@override_settings( + DISCUSSION_MODERATION_EDIT_REASON_CODES={"test-edit-reason": "Test Edit Reason"} +) +@override_settings( + DISCUSSION_MODERATION_CLOSE_REASON_CODES={"test-close-reason": "Test Close Reason"} +) class CourseViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): """Tests for CourseView""" def setUp(self): super().setUp() - self.url = reverse("discussion_course", kwargs={"course_id": str(self.course.id)}) + self.url = reverse( + "discussion_course", kwargs={"course_id": str(self.course.id)} + ) patcher = mock.patch( - 'openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled', - return_value=False + "openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled", + return_value=False, ) patcher.start() self.addCleanup(patcher.stop) @@ -521,9 +605,7 @@ def test_404(self): reverse("course_topics", kwargs={"course_id": "non/existent/course"}) ) self.assert_response_correct( - response, - 404, - {"developer_message": "Course not found."} + response, 404, {"developer_message": "Course not found."} ) def test_basic(self): @@ -534,6 +616,7 @@ def test_basic(self): { "id": str(self.course.id), "is_posting_enabled": True, + "is_user_banned": False, "blackouts": [], "thread_list_url": "http://testserver/api/discussion/v1/threads/?course_id=course-v1%3Ax%2By%2Bz", "following_thread_list_url": ( @@ -547,23 +630,28 @@ def test_basic(self): "allow_anonymous_to_peers": False, "has_bulk_delete_privileges": False, "has_moderation_privileges": False, - 'is_course_admin': False, - 'is_course_staff': False, + "is_course_admin": False, + "is_course_staff": False, "is_group_ta": False, - 'is_user_admin': False, + "is_user_admin": False, "user_roles": ["Student"], - "edit_reasons": [{"code": "test-edit-reason", "label": "Test Edit Reason"}], - "post_close_reasons": [{"code": "test-close-reason", "label": "Test Close Reason"}], - 'show_discussions': True, - 'is_notify_all_learners_enabled': False, - 'captcha_settings': { - 'enabled': False, - 'site_key': None, + "edit_reasons": [ + {"code": "test-edit-reason", "label": "Test Edit Reason"} + ], + "post_close_reasons": [ + {"code": "test-close-reason", "label": "Test Close Reason"} + ], + "show_discussions": True, + "is_notify_all_learners_enabled": False, + "captcha_settings": { + "enabled": False, + "site_key": None, }, "is_email_verified": True, "only_verified_users_can_post": False, - "content_creation_rate_limited": False - } + "content_creation_rate_limited": False, + "enable_discussion_ban": False, + }, ) @@ -574,8 +662,10 @@ class RetireViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): def setUp(self): super().setUp() - RetirementState.objects.create(state_name='PENDING', state_execution_order=1) - self.retire_forums_state = RetirementState.objects.create(state_name='RETIRE_FORUMS', state_execution_order=11) + RetirementState.objects.create(state_name="PENDING", state_execution_order=1) + self.retire_forums_state = RetirementState.objects.create( + state_name="RETIRE_FORUMS", state_execution_order=11 + ) self.retirement = UserRetirementStatus.create_retirement(self.user) self.retirement.current_state = self.retire_forums_state @@ -586,8 +676,8 @@ def setUp(self): self.retired_username = get_retired_username_by_username(self.user.username) self.url = reverse("retire_discussion_user") patcher = mock.patch( - 'openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled', - return_value=False + "openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled", + return_value=False, ) patcher.start() self.addCleanup(patcher.stop) @@ -599,14 +689,14 @@ def assert_response_correct(self, response, expected_status, expected_content): assert response.status_code == expected_status if expected_content: - assert response.content.decode('utf-8') == expected_content + assert response.content.decode("utf-8") == expected_content def build_jwt_headers(self, user): """ Helper function for creating headers for the JWT authentication. """ token = create_jwt_for_user(user) - headers = {'HTTP_AUTHORIZATION': 'JWT ' + token} + headers = {"HTTP_AUTHORIZATION": "JWT " + token} return headers def test_basic(self): @@ -615,7 +705,7 @@ def test_basic(self): """ self.register_get_user_retire_response(self.user) headers = self.build_jwt_headers(self.superuser) - data = {'username': self.user.username} + data = {"username": self.user.username} response = self.superuser_client.post(self.url, data, **headers) self.assert_response_correct(response, 204, b"") @@ -623,9 +713,11 @@ def test_downstream_forums_error(self): """ Check that we bubble up errors from the comments service """ - self.register_get_user_retire_response(self.user, status=500, body="Server error") + self.register_get_user_retire_response( + self.user, status=500, body="Server error" + ) headers = self.build_jwt_headers(self.superuser) - data = {'username': self.user.username} + data = {"username": self.user.username} response = self.superuser_client.post(self.url, data, **headers) self.assert_response_correct(response, 500, '"Server error"') @@ -635,7 +727,7 @@ def test_nonexistent_user(self): """ nonexistent_username = "nonexistent user" self.retired_username = get_retired_username_by_username(nonexistent_username) - data = {'username': nonexistent_username} + data = {"username": nonexistent_username} headers = self.build_jwt_headers(self.superuser) response = self.superuser_client.post(self.url, data, **headers) self.assert_response_correct(response, 404, None) @@ -649,7 +741,10 @@ def test_not_authenticated(self): @ddt.ddt @httpretty.activate -@mock.patch('django.conf.settings.USERNAME_REPLACEMENT_WORKER', 'test_replace_username_service_worker') +@mock.patch( + "django.conf.settings.USERNAME_REPLACEMENT_WORKER", + "test_replace_username_service_worker", +) @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) class ReplaceUsernamesViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): """Tests for ReplaceUsernamesView""" @@ -662,8 +757,8 @@ def setUp(self): self.new_username = "test_username_replacement" self.url = reverse("replace_discussion_username") patcher = mock.patch( - 'openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled', - return_value=False + "openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled", + return_value=False, ) patcher.start() self.addCleanup(patcher.stop) @@ -682,34 +777,28 @@ def build_jwt_headers(self, user): Helper function for creating headers for the JWT authentication. """ token = create_jwt_for_user(user) - headers = {'HTTP_AUTHORIZATION': 'JWT ' + token} + headers = {"HTTP_AUTHORIZATION": "JWT " + token} return headers def call_api(self, user, client, data): - """ Helper function to call API with data """ + """Helper function to call API with data""" data = json.dumps(data) headers = self.build_jwt_headers(user) - return client.post(self.url, data, content_type='application/json', **headers) + return client.post(self.url, data, content_type="application/json", **headers) - @ddt.data( - [{}, {}], - {}, - [{"test_key": "test_value", "test_key_2": "test_value_2"}] - ) + @ddt.data([{}, {}], {}, [{"test_key": "test_value", "test_key_2": "test_value_2"}]) def test_bad_schema(self, mapping_data): - """ Verify the endpoint rejects bad data schema """ - data = { - "username_mappings": mapping_data - } + """Verify the endpoint rejects bad data schema""" + data = {"username_mappings": mapping_data} response = self.call_api(self.worker, self.worker_client, data) assert response.status_code == 400 def test_auth(self): - """ Verify the endpoint only works with the service worker """ + """Verify the endpoint only works with the service worker""" data = { "username_mappings": [ {"test_username_1": "test_new_username_1"}, - {"test_username_2": "test_new_username_2"} + {"test_username_2": "test_new_username_2"}, ] } @@ -727,15 +816,15 @@ def test_auth(self): assert response.status_code == 200 def test_basic(self): - """ Check successful replacement """ + """Check successful replacement""" data = { "username_mappings": [ {self.user.username: self.new_username}, ] } expected_response = { - 'failed_replacements': [], - 'successful_replacements': data["username_mappings"] + "failed_replacements": [], + "successful_replacements": data["username_mappings"], } self.register_get_username_replacement_response(self.user) response = self.call_api(self.worker, self.worker_client, data) @@ -751,7 +840,9 @@ def test_not_authenticated(self): @ddt.ddt @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) -class CourseTopicsViewTest(DiscussionAPIViewTestMixin, CommentsServiceMockMixin, ModuleStoreTestCase): +class CourseTopicsViewTest( + DiscussionAPIViewTestMixin, CommentsServiceMockMixin, ModuleStoreTestCase +): """ Tests for CourseTopicsView """ @@ -768,10 +859,12 @@ def setUp(self): "courseware-2": {"discussion": 4, "question": 5}, "courseware-3": {"discussion": 7, "question": 2}, } - self.register_get_course_commentable_counts_response(self.course.id, self.thread_counts_map) + self.register_get_course_commentable_counts_response( + self.course.id, self.thread_counts_map + ) patcher = mock.patch( - 'openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled', - return_value=False + "openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled", + return_value=False, ) patcher.start() self.addCleanup(patcher.stop) @@ -786,7 +879,7 @@ def create_course(self, blocks_count, module_store, topics): run="c", start=datetime.now(UTC), default_store=module_store, - discussion_topics=topics + discussion_topics=topics, ) CourseEnrollmentFactory.create(user=self.user, course_id=course.id) course_url = reverse("course_topics", kwargs={"course_id": str(course.id)}) @@ -794,10 +887,10 @@ def create_course(self, blocks_count, module_store, topics): for i in range(blocks_count): BlockFactory.create( parent_location=course.location, - category='discussion', - discussion_id=f'id_module_{i}', - discussion_category=f'Category {i}', - discussion_target=f'Discussion {i}', + category="discussion", + discussion_id=f"id_module_{i}", + discussion_category=f"Category {i}", + discussion_target=f"Discussion {i}", publish_item=False, ) return course_url, course.id @@ -812,7 +905,7 @@ def make_discussion_xblock(self, topic_id, category, subcategory, **kwargs): discussion_id=topic_id, discussion_category=category, discussion_target=subcategory, - **kwargs + **kwargs, ) def test_404(self): @@ -820,9 +913,7 @@ def test_404(self): reverse("course_topics", kwargs={"course_id": "non/existent/course"}) ) self.assert_response_correct( - response, - 404, - {"developer_message": "Course not found."} + response, 404, {"developer_message": "Course not found."} ) def test_basic(self): @@ -832,21 +923,30 @@ def test_basic(self): 200, { "courseware_topics": [], - "non_courseware_topics": [{ - "id": "test_topic", - "name": "Test Topic", - "children": [], - "thread_list_url": 'http://testserver/api/discussion/v1/threads/' - '?course_id=course-v1%3Ax%2By%2Bz&topic_id=test_topic', - "thread_counts": {"discussion": 0, "question": 0}, - }], - } + "non_courseware_topics": [ + { + "id": "test_topic", + "name": "Test Topic", + "children": [], + "thread_list_url": "http://testserver/api/discussion/v1/threads/" + "?course_id=course-v1%3Ax%2By%2Bz&topic_id=test_topic", + "thread_counts": {"discussion": 0, "question": 0}, + } + ], + }, ) @ddt.data( (2, ModuleStoreEnum.Type.split, 2, {"Test Topic 1": {"id": "test_topic_1"}}), - (2, ModuleStoreEnum.Type.split, 2, - {"Test Topic 1": {"id": "test_topic_1"}, "Test Topic 2": {"id": "test_topic_2"}}), + ( + 2, + ModuleStoreEnum.Type.split, + 2, + { + "Test Topic 1": {"id": "test_topic_1"}, + "Test Topic 2": {"id": "test_topic_2"}, + }, + ), (10, ModuleStoreEnum.Type.split, 2, {"Test Topic 1": {"id": "test_topic_1"}}), ) @ddt.unpack @@ -868,7 +968,7 @@ def test_discussion_topic_404(self): self.assert_response_correct( response, 404, - {"developer_message": "Discussion not found for 'invalid_topic_id'."} + {"developer_message": "Discussion not found for 'invalid_topic_id'."}, ) def test_topic_id(self): @@ -888,38 +988,41 @@ def test_topic_id(self): "non_courseware_topics": [], "courseware_topics": [ { - "children": [{ - "children": [], - "id": "topic_id_1", - "thread_list_url": "http://testserver/api/discussion/v1/threads/?" - "course_id=course-v1%3Ax%2By%2Bz&topic_id=topic_id_1", - "name": "test_target_1", - "thread_counts": {"discussion": 0, "question": 0}, - }], + "children": [ + { + "children": [], + "id": "topic_id_1", + "thread_list_url": "http://testserver/api/discussion/v1/threads/?" + "course_id=course-v1%3Ax%2By%2Bz&topic_id=topic_id_1", + "name": "test_target_1", + "thread_counts": {"discussion": 0, "question": 0}, + } + ], "id": None, "thread_list_url": "http://testserver/api/discussion/v1/threads/?" - "course_id=course-v1%3Ax%2By%2Bz&topic_id=topic_id_1", + "course_id=course-v1%3Ax%2By%2Bz&topic_id=topic_id_1", "name": "test_category_1", "thread_counts": None, }, { - "children": - [{ + "children": [ + { "children": [], "id": "topic_id_2", "thread_list_url": "http://testserver/api/discussion/v1/threads/?" - "course_id=course-v1%3Ax%2By%2Bz&topic_id=topic_id_2", + "course_id=course-v1%3Ax%2By%2Bz&topic_id=topic_id_2", "name": "test_target_2", "thread_counts": {"discussion": 0, "question": 0}, - }], + } + ], "id": None, "thread_list_url": "http://testserver/api/discussion/v1/threads/?" - "course_id=course-v1%3Ax%2By%2Bz&topic_id=topic_id_2", + "course_id=course-v1%3Ax%2By%2Bz&topic_id=topic_id_2", "name": "test_category_2", "thread_counts": None, - } - ] - } + }, + ], + }, ) @override_waffle_flag(ENABLE_NEW_STRUCTURE_DISCUSSIONS, True) @@ -930,45 +1033,46 @@ def test_new_course_structure_response(self): """ chapter = BlockFactory.create( parent_location=self.course.location, - category='chapter', + category="chapter", display_name="Week 1", start=datetime(2015, 3, 1, tzinfo=UTC), ) sequential = BlockFactory.create( parent_location=chapter.location, - category='sequential', + category="sequential", display_name="Lesson 1", start=datetime(2015, 3, 1, tzinfo=UTC), ) BlockFactory.create( parent_location=sequential.location, - category='vertical', - display_name='vertical', + category="vertical", + display_name="vertical", start=datetime(2015, 4, 1, tzinfo=UTC), ) DiscussionsConfiguration.objects.create( - context_key=self.course.id, - provider_type=Provider.OPEN_EDX + context_key=self.course.id, provider_type=Provider.OPEN_EDX ) update_discussions_settings_from_course_task(str(self.course.id)) response = json.loads(self.client.get(self.url).content.decode()) - keys = ['children', 'id', 'name', 'thread_counts', 'thread_list_url'] - assert list(response.keys()) == ['courseware_topics', 'non_courseware_topics'] - assert len(response['courseware_topics']) == 1 - courseware_keys = list(response['courseware_topics'][0].keys()) + keys = ["children", "id", "name", "thread_counts", "thread_list_url"] + assert list(response.keys()) == ["courseware_topics", "non_courseware_topics"] + assert len(response["courseware_topics"]) == 1 + courseware_keys = list(response["courseware_topics"][0].keys()) courseware_keys.sort() assert courseware_keys == keys - assert len(response['non_courseware_topics']) == 1 - non_courseware_keys = list(response['non_courseware_topics'][0].keys()) + assert len(response["non_courseware_topics"]) == 1 + non_courseware_keys = list(response["non_courseware_topics"][0].keys()) non_courseware_keys.sort() assert non_courseware_keys == keys @ddt.ddt -@mock.patch('lms.djangoapps.discussion.rest_api.api._get_course', mock.Mock()) +@mock.patch("lms.djangoapps.discussion.rest_api.api._get_course", mock.Mock()) @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) @override_waffle_flag(ENABLE_NEW_STRUCTURE_DISCUSSIONS, True) -class CourseTopicsViewV3Test(DiscussionAPIViewTestMixin, CommentsServiceMockMixin, ModuleStoreTestCase): +class CourseTopicsViewV3Test( + DiscussionAPIViewTestMixin, CommentsServiceMockMixin, ModuleStoreTestCase +): """ Tests for CourseTopicsViewV3 """ @@ -984,55 +1088,68 @@ def setUp(self) -> None: end=datetime(2028, 1, 1), enrollment_start=datetime(2020, 1, 1), enrollment_end=datetime(2028, 1, 1), - discussion_topics={"Course Wide Topic": { - "id": 'course-wide-topic', - "usage_key": None, - }} + discussion_topics={ + "Course Wide Topic": { + "id": "course-wide-topic", + "usage_key": None, + } + }, ) self.chapter = BlockFactory.create( parent_location=self.course.location, - category='chapter', + category="chapter", display_name="Week 1", start=datetime(2015, 3, 1, tzinfo=UTC), ) self.sequential = BlockFactory.create( parent_location=self.chapter.location, - category='sequential', + category="sequential", display_name="Lesson 1", start=datetime(2015, 3, 1, tzinfo=UTC), ) self.verticals = [ BlockFactory.create( parent_location=self.sequential.location, - category='vertical', - display_name='vertical', + category="vertical", + display_name="vertical", start=datetime(2015, 4, 1, tzinfo=UTC), ) ] course_key = self.course.id - self.config = DiscussionsConfiguration.objects.create(context_key=course_key, provider_type=Provider.OPEN_EDX) + self.config = DiscussionsConfiguration.objects.create( + context_key=course_key, provider_type=Provider.OPEN_EDX + ) topic_links = [] update_discussions_settings_from_course_task(str(course_key)) - topic_id_query = DiscussionTopicLink.objects.filter(context_key=course_key).values_list( - 'external_id', flat=True, + topic_id_query = DiscussionTopicLink.objects.filter( + context_key=course_key + ).values_list( + "external_id", + flat=True, ) - topic_ids = list(topic_id_query.order_by('ordering')) + topic_ids = list(topic_id_query.order_by("ordering")) DiscussionTopicLink.objects.bulk_create(topic_links) self.topic_stats = { - **{topic_id: dict(discussion=random.randint(0, 10), question=random.randint(0, 10)) - for topic_id in set(topic_ids)}, + **{ + topic_id: dict( + discussion=random.randint(0, 10), question=random.randint(0, 10) + ) + for topic_id in set(topic_ids) + }, topic_ids[0]: dict(discussion=0, question=0), } patcher = mock.patch( - 'lms.djangoapps.discussion.rest_api.api.get_course_commentable_counts', + "lms.djangoapps.discussion.rest_api.api.get_course_commentable_counts", mock.Mock(return_value=self.topic_stats), ) patcher.start() self.addCleanup(patcher.stop) - self.url = reverse("course_topics_v3", kwargs={"course_id": str(self.course.id)}) + self.url = reverse( + "course_topics_v3", kwargs={"course_id": str(self.course.id)} + ) patcher = mock.patch( - 'openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled', - return_value=False + "openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled", + return_value=False, ) patcher.start() self.addCleanup(patcher.stop) @@ -1041,12 +1158,23 @@ def test_basic(self): response = self.client.get(self.url) data = json.loads(response.content.decode()) expected_non_courseware_keys = [ - 'id', 'usage_key', 'name', 'thread_counts', 'enabled_in_context', - 'courseware' + "id", + "usage_key", + "name", + "thread_counts", + "enabled_in_context", + "courseware", ] expected_courseware_keys = [ - 'id', 'block_id', 'lms_web_url', 'legacy_web_url', 'student_view_url', - 'type', 'display_name', 'children', 'courseware' + "id", + "block_id", + "lms_web_url", + "legacy_web_url", + "student_view_url", + "type", + "display_name", + "children", + "courseware", ] assert response.status_code == 200 assert len(data) == 2 @@ -1054,11 +1182,11 @@ def test_basic(self): assert non_courseware_topic_keys == expected_non_courseware_keys courseware_topic_keys = list(data[1].keys()) assert courseware_topic_keys == expected_courseware_keys - expected_courseware_keys.remove('courseware') - sequential_keys = list(data[1]['children'][0].keys()) - assert sequential_keys == (expected_courseware_keys + ['thread_counts']) - expected_non_courseware_keys.remove('courseware') - vertical_keys = list(data[1]['children'][0]['children'][0].keys()) + expected_courseware_keys.remove("courseware") + sequential_keys = list(data[1]["children"][0].keys()) + assert sequential_keys == (expected_courseware_keys + ["thread_counts"]) + expected_non_courseware_keys.remove("courseware") + vertical_keys = list(data[1]["children"][0]["children"][0].keys()) assert vertical_keys == expected_non_courseware_keys @@ -1095,53 +1223,70 @@ def setUp(self): {"key": "author", "value": self.author.username}, {"key": "abuse_flagged", "value": False}, {"key": "author_label", "value": None}, + {"key": "is_author_banned", "value": False}, + {"key": "author_ban_scope", "value": None}, {"key": "can_delete", "value": True}, {"key": "close_reason", "value": None}, { "key": "comment_list_url", - "value": "http://testserver/api/discussion/v1/comments/?thread_id=test_thread" + "value": "http://testserver/api/discussion/v1/comments/?thread_id=test_thread", }, { "key": "editable_fields", "value": [ - 'abuse_flagged', 'anonymous', 'copy_link', 'following', 'raw_body', - 'read', 'title', 'topic_id', 'type' - ] + "abuse_flagged", + "anonymous", + "copy_link", + "following", + "raw_body", + "read", + "title", + "topic_id", + "type", + ], }, {"key": "endorsed_comment_list_url", "value": None}, {"key": "following", "value": False}, {"key": "group_name", "value": None}, {"key": "has_endorsed", "value": False}, {"key": "last_edit", "value": None}, + {"key": "learner_status", "value": "new"}, {"key": "non_endorsed_comment_list_url", "value": None}, {"key": "preview_body", "value": "Test body"}, {"key": "raw_body", "value": "Test body"}, - {"key": "rendered_body", "value": "

Test body

"}, {"key": "response_count", "value": 0}, {"key": "topic_id", "value": "test_topic"}, {"key": "type", "value": "discussion"}, - {"key": "users", "value": { - self.user.username: { - "profile": { - "image": { - "has_image": False, - "image_url_full": "http://testserver/static/default_500.png", - "image_url_large": "http://testserver/static/default_120.png", - "image_url_medium": "http://testserver/static/default_50.png", - "image_url_small": "http://testserver/static/default_30.png", + { + "key": "users", + "value": { + self.user.username: { + "profile": { + "image": { + "has_image": False, + "image_url_full": "http://testserver/static/default_500.png", + "image_url_large": "http://testserver/static/default_120.png", + "image_url_medium": "http://testserver/static/default_50.png", + "image_url_small": "http://testserver/static/default_30.png", + } } } - } - }}, + }, + }, {"key": "vote_count", "value": 4}, {"key": "voted", "value": False}, - + {"key": "is_deleted", "value": False}, + {"key": "deleted_at", "value": None}, + {"key": "deleted_by", "value": None}, + {"key": "deleted_by_label", "value": None}, ] - self.url = reverse("discussion_learner_threads", kwargs={'course_id': str(self.course.id)}) + self.url = reverse( + "discussion_learner_threads", kwargs={"course_id": str(self.course.id)} + ) patcher = mock.patch( - 'openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled', - return_value=False + "openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled", + return_value=False, ) patcher.start() self.addCleanup(patcher.stop) @@ -1152,12 +1297,12 @@ def update_thread(self, thread): Value of these keys has been defined in setUp function """ for element in self.add_keys: - thread[element['key']] = element['value'] + thread[element["key"]] = element["value"] for pair in self.replace_keys: - thread[pair['to']] = thread.pop(pair['from']) + thread[pair["to"]] = thread.pop(pair["from"]) for key in self.remove_keys: thread.pop(key) - thread['comment_count'] += 1 + thread["comment_count"] += 1 return thread def test_basic(self): @@ -1169,22 +1314,26 @@ def test_basic(self): """ self.register_get_user_response(self.user) expected_cs_comments_response = { - "collection": [make_minimal_cs_thread({ - "id": "test_thread", - "course_id": str(self.course.id), - "commentable_id": "test_topic", - "user_id": str(self.user.id), - "username": self.user.username, - "created_at": "2015-04-28T00:00:00Z", - "updated_at": "2015-04-28T11:11:11Z", - "title": "Test Title", - "body": "Test body", - "votes": {"up_count": 4}, - "comments_count": 5, - "unread_comments_count": 3, - "closed_by_label": None, - "edit_by_label": None, - })], + "collection": [ + make_minimal_cs_thread( + { + "id": "test_thread", + "course_id": str(self.course.id), + "commentable_id": "test_topic", + "user_id": str(self.user.id), + "username": self.user.username, + "created_at": "2015-04-28T00:00:00Z", + "updated_at": "2015-04-28T11:11:11Z", + "title": "Test Title", + "body": "Test body", + "votes": {"up_count": 4}, + "comments_count": 5, + "unread_comments_count": 3, + "closed_by_label": None, + "edit_by_label": None, + } + ) + ], "page": 1, "num_pages": 1, } @@ -1192,14 +1341,14 @@ def test_basic(self): self.url += f"?username={self.user.username}" response = self.client.get(self.url) assert response.status_code == 200 - response_data = json.loads(response.content.decode('utf-8')) - expected_api_response = expected_cs_comments_response['collection'] + response_data = json.loads(response.content.decode("utf-8")) + expected_api_response = expected_cs_comments_response["collection"] for thread in expected_api_response: self.update_thread(thread) - assert response_data['results'] == expected_api_response - assert response_data['pagination'] == { + assert response_data["results"] == expected_api_response + assert response_data["pagination"] == { "next": None, "previous": None, "count": 1, @@ -1229,20 +1378,24 @@ def test_thread_type_by(self, thread_type): thread_type (str): Value of thread_type can be 'None', 'discussion' and 'question' """ - threads = [make_minimal_cs_thread({ - "id": "test_thread", - "course_id": str(self.course.id), - "commentable_id": "test_topic", - "user_id": str(self.user.id), - "username": self.user.username, - "created_at": "2015-04-28T00:00:00Z", - "updated_at": "2015-04-28T11:11:11Z", - "title": "Test Title", - "body": "Test body", - "votes": {"up_count": 4}, - "comments_count": 5, - "unread_comments_count": 3, - })] + threads = [ + make_minimal_cs_thread( + { + "id": "test_thread", + "course_id": str(self.course.id), + "commentable_id": "test_topic", + "user_id": str(self.user.id), + "username": self.user.username, + "created_at": "2015-04-28T00:00:00Z", + "updated_at": "2015-04-28T11:11:11Z", + "title": "Test Title", + "body": "Test body", + "votes": {"up_count": 4}, + "comments_count": 5, + "unread_comments_count": 3, + } + ) + ] expected_cs_comments_response = { "collection": threads, "page": 1, @@ -1256,23 +1409,26 @@ def test_thread_type_by(self, thread_type): "course_id": str(self.course.id), "username": self.user.username, "thread_type": thread_type, - } + }, ) assert response.status_code == 200 - self.assert_last_query_params({ - "user_id": [str(self.user.id)], - "course_id": [str(self.course.id)], - "page": ["1"], - "per_page": ["10"], - "thread_type": [thread_type], - "sort_key": ['activity'], - "count_flagged": ["False"] - }) + self.assert_last_query_params( + { + "user_id": [str(self.user.id)], + "course_id": [str(self.course.id)], + "page": ["1"], + "per_page": ["10"], + "thread_type": [thread_type], + "sort_key": ["activity"], + "count_flagged": ["False"], + "show_deleted": ["False"], + } + ) @ddt.data( ("last_activity_at", "activity"), ("comment_count", "comments"), - ("vote_count", "votes") + ("vote_count", "votes"), ) @ddt.unpack def test_order_by(self, http_query, cc_query): @@ -1283,20 +1439,24 @@ def test_order_by(self, http_query, cc_query): http_query (str): Query string sent in the http request cc_query (str): Query string used for the comments client service """ - threads = [make_minimal_cs_thread({ - "id": "test_thread", - "course_id": str(self.course.id), - "commentable_id": "test_topic", - "user_id": str(self.user.id), - "username": self.user.username, - "created_at": "2015-04-28T00:00:00Z", - "updated_at": "2015-04-28T11:11:11Z", - "title": "Test Title", - "body": "Test body", - "votes": {"up_count": 4}, - "comments_count": 5, - "unread_comments_count": 3, - })] + threads = [ + make_minimal_cs_thread( + { + "id": "test_thread", + "course_id": str(self.course.id), + "commentable_id": "test_topic", + "user_id": str(self.user.id), + "username": self.user.username, + "created_at": "2015-04-28T00:00:00Z", + "updated_at": "2015-04-28T11:11:11Z", + "title": "Test Title", + "body": "Test body", + "votes": {"up_count": 4}, + "comments_count": 5, + "unread_comments_count": 3, + } + ) + ] expected_cs_comments_response = { "collection": threads, "page": 1, @@ -1310,17 +1470,20 @@ def test_order_by(self, http_query, cc_query): "course_id": str(self.course.id), "username": self.user.username, "order_by": http_query, - } + }, ) assert response.status_code == 200 - self.assert_last_query_params({ - "user_id": [str(self.user.id)], - "course_id": [str(self.course.id)], - "page": ["1"], - "per_page": ["10"], - "sort_key": [cc_query], - "count_flagged": ["False"] - }) + self.assert_last_query_params( + { + "user_id": [str(self.user.id)], + "course_id": [str(self.course.id)], + "page": ["1"], + "per_page": ["10"], + "sort_key": [cc_query], + "count_flagged": ["False"], + "show_deleted": ["False"], + } + ) @ddt.data("flagged", "unanswered", "unread", "unresponded") def test_status_by(self, post_status): @@ -1331,20 +1494,24 @@ def test_status_by(self, post_status): post_status (str): Value of post_status can be 'flagged', 'unanswered' and 'unread' """ - threads = [make_minimal_cs_thread({ - "id": "test_thread", - "course_id": str(self.course.id), - "commentable_id": "test_topic", - "user_id": str(self.user.id), - "username": self.user.username, - "created_at": "2015-04-28T00:00:00Z", - "updated_at": "2015-04-28T11:11:11Z", - "title": "Test Title", - "body": "Test body", - "votes": {"up_count": 4}, - "comments_count": 5, - "unread_comments_count": 3, - })] + threads = [ + make_minimal_cs_thread( + { + "id": "test_thread", + "course_id": str(self.course.id), + "commentable_id": "test_topic", + "user_id": str(self.user.id), + "username": self.user.username, + "created_at": "2015-04-28T00:00:00Z", + "updated_at": "2015-04-28T11:11:11Z", + "title": "Test Title", + "body": "Test body", + "votes": {"up_count": 4}, + "comments_count": 5, + "unread_comments_count": 3, + } + ) + ] expected_cs_comments_response = { "collection": threads, "page": 1, @@ -1358,29 +1525,37 @@ def test_status_by(self, post_status): "course_id": str(self.course.id), "username": self.user.username, "status": post_status, - } + }, ) if post_status == "flagged": assert response.status_code == 403 else: assert response.status_code == 200 - self.assert_last_query_params({ - "user_id": [str(self.user.id)], - "course_id": [str(self.course.id)], - "page": ["1"], - "per_page": ["10"], - post_status: ['True'], - "sort_key": ['activity'], - "count_flagged": ["False"] - }) + self.assert_last_query_params( + { + "user_id": [str(self.user.id)], + "course_id": [str(self.course.id)], + "page": ["1"], + "per_page": ["10"], + post_status: ["True"], + "sort_key": ["activity"], + "count_flagged": ["False"], + "show_deleted": ["False"], + } + ) @ddt.ddt -class CourseDiscussionSettingsAPIViewTest(APITestCase, UrlResetMixin, ModuleStoreTestCase): +class CourseDiscussionSettingsAPIViewTest( + APITestCase, UrlResetMixin, ModuleStoreTestCase +): """ Test the course discussion settings handler API endpoint. """ - @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) + + @mock.patch.dict( + "django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True} + ) def setUp(self): super().setUp() self.course = CourseFactory.create( @@ -1388,24 +1563,26 @@ def setUp(self): course="y", run="z", start=datetime.now(UTC), - discussion_topics={"Test Topic": {"id": "test_topic"}} + discussion_topics={"Test Topic": {"id": "test_topic"}}, + ) + self.path = reverse( + "discussion_course_settings", kwargs={"course_id": str(self.course.id)} ) - self.path = reverse('discussion_course_settings', kwargs={'course_id': str(self.course.id)}) self.password = self.TEST_PASSWORD - self.user = UserFactory(username='staff', password=self.password, is_staff=True) + self.user = UserFactory(username="staff", password=self.password, is_staff=True) patcher = mock.patch( - 'openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled', - return_value=False + "openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled", + return_value=False, ) patcher.start() self.addCleanup(patcher.stop) def _get_oauth_headers(self, user): """Return the OAuth headers for testing OAuth authentication""" - access_token = AccessTokenFactory.create(user=user, application=ApplicationFactory()).token - headers = { - 'HTTP_AUTHORIZATION': 'Bearer ' + access_token - } + access_token = AccessTokenFactory.create( + user=user, application=ApplicationFactory() + ).token + headers = {"HTTP_AUTHORIZATION": "Bearer " + access_token} return headers def _login_as_staff(self): @@ -1413,24 +1590,30 @@ def _login_as_staff(self): self.client.login(username=self.user.username, password=self.password) def _login_as_discussion_staff(self): - user = UserFactory(username='abc', password='abc') - role = Role.objects.create(name='Administrator', course_id=self.course.id) + user = UserFactory(username="abc", password="abc") + role = Role.objects.create(name="Administrator", course_id=self.course.id) role.users.set([user]) - self.client.login(username=user.username, password='abc') + self.client.login(username=user.username, password="abc") def _create_divided_discussions(self): """Create some divided discussions for testing.""" - divided_inline_discussions = ['Topic A', ] - divided_course_wide_discussions = ['Topic B', ] - divided_discussions = divided_inline_discussions + divided_course_wide_discussions + divided_inline_discussions = [ + "Topic A", + ] + divided_course_wide_discussions = [ + "Topic B", + ] + divided_discussions = ( + divided_inline_discussions + divided_course_wide_discussions + ) BlockFactory.create( parent=self.course, - category='discussion', - discussion_id=topic_name_to_id(self.course, 'Topic A'), - discussion_category='Chapter', - discussion_target='Discussion', - start=datetime.now() + category="discussion", + discussion_id=topic_name_to_id(self.course, "Topic A"), + discussion_category="Chapter", + discussion_target="Discussion", + start=datetime.now(), ) discussion_topics = { "Topic B": {"id": "Topic B"}, @@ -1439,31 +1622,36 @@ def _create_divided_discussions(self): config_course_discussions( self.course, discussion_topics=discussion_topics, - divided_discussions=divided_discussions + divided_discussions=divided_discussions, ) return divided_inline_discussions, divided_course_wide_discussions def _get_expected_response(self): """Return the default expected response before any changes to the discussion settings.""" return { - 'always_divide_inline_discussions': False, - 'divided_inline_discussions': [], - 'divided_course_wide_discussions': [], - 'id': 1, - 'division_scheme': 'cohort', - 'available_division_schemes': ['cohort'], - 'reported_content_email_notifications': False, + "always_divide_inline_discussions": False, + "divided_inline_discussions": [], + "divided_course_wide_discussions": [], + "id": 1, + "division_scheme": "cohort", + "available_division_schemes": ["cohort"], + "reported_content_email_notifications": False, } def patch_request(self, data, headers=None): headers = headers if headers else {} - return self.client.patch(self.path, json.dumps(data), content_type='application/merge-patch+json', **headers) + return self.client.patch( + self.path, + json.dumps(data), + content_type="application/merge-patch+json", + **headers, + ) def _assert_current_settings(self, expected_response): """Validate the current discussion settings against the expected response.""" response = self.client.get(self.path) assert response.status_code == 200 - content = json.loads(response.content.decode('utf-8')) + content = json.loads(response.content.decode("utf-8")) assert content == expected_response def _assert_patched_settings(self, data, expected_response): @@ -1472,7 +1660,7 @@ def _assert_patched_settings(self, data, expected_response): assert response.status_code == 204 self._assert_current_settings(expected_response) - @ddt.data('get', 'patch') + @ddt.data("get", "patch") def test_authentication_required(self, method): """Test and verify that authentication is required for this endpoint.""" self.client.logout() @@ -1480,8 +1668,8 @@ def test_authentication_required(self, method): assert response.status_code == 401 @ddt.data( - {'is_staff': False, 'get_status': 403, 'put_status': 403}, - {'is_staff': True, 'get_status': 200, 'put_status': 204}, + {"is_staff": False, "get_status": 403, "put_status": 403}, + {"is_staff": True, "get_status": 200, "put_status": 204}, ) @ddt.unpack def test_oauth(self, is_staff, get_status, put_status): @@ -1494,7 +1682,7 @@ def test_oauth(self, is_staff, get_status, put_status): assert response.status_code == get_status response = self.patch_request( - {'always_divide_inline_discussions': True}, headers + {"always_divide_inline_discussions": True}, headers ) assert response.status_code == put_status @@ -1502,66 +1690,68 @@ def test_non_existent_course_id(self): """Test the response when this endpoint is passed a non-existent course id.""" self._login_as_staff() response = self.client.get( - reverse('discussion_course_settings', kwargs={ - 'course_id': 'course-v1:a+b+c' - }) + reverse( + "discussion_course_settings", kwargs={"course_id": "course-v1:a+b+c"} + ) ) assert response.status_code == 404 def test_patch_request_by_discussion_staff(self): """Test the response when patch request is sent by a user with discussions staff role.""" self._login_as_discussion_staff() - response = self.patch_request( - {'always_divide_inline_discussions': True} - ) + response = self.patch_request({"always_divide_inline_discussions": True}) assert response.status_code == 403 def test_get_request_by_discussion_staff(self): """Test the response when get request is sent by a user with discussions staff role.""" self._login_as_discussion_staff() - divided_inline_discussions, divided_course_wide_discussions = self._create_divided_discussions() + divided_inline_discussions, divided_course_wide_discussions = ( + self._create_divided_discussions() + ) response = self.client.get(self.path) assert response.status_code == 200 expected_response = self._get_expected_response() - expected_response['divided_course_wide_discussions'] = [ - topic_name_to_id(self.course, name) for name in divided_course_wide_discussions + expected_response["divided_course_wide_discussions"] = [ + topic_name_to_id(self.course, name) + for name in divided_course_wide_discussions ] - expected_response['divided_inline_discussions'] = [ + expected_response["divided_inline_discussions"] = [ topic_name_to_id(self.course, name) for name in divided_inline_discussions ] - content = json.loads(response.content.decode('utf-8')) + content = json.loads(response.content.decode("utf-8")) assert content == expected_response def test_get_request_by_non_staff_user(self): """Test the response when get request is sent by a regular user with no staff role.""" - user = UserFactory(username='abc', password='abc') - self.client.login(username=user.username, password='abc') + user = UserFactory(username="abc", password="abc") + self.client.login(username=user.username, password="abc") response = self.client.get(self.path) assert response.status_code == 403 def test_patch_request_by_non_staff_user(self): """Test the response when patch request is sent by a regular user with no staff role.""" - user = UserFactory(username='abc', password='abc') - self.client.login(username=user.username, password='abc') - response = self.patch_request( - {'always_divide_inline_discussions': True} - ) + user = UserFactory(username="abc", password="abc") + self.client.login(username=user.username, password="abc") + response = self.patch_request({"always_divide_inline_discussions": True}) assert response.status_code == 403 def test_get_settings(self): """Test the current discussion settings against the expected response.""" - divided_inline_discussions, divided_course_wide_discussions = self._create_divided_discussions() + divided_inline_discussions, divided_course_wide_discussions = ( + self._create_divided_discussions() + ) self._login_as_staff() response = self.client.get(self.path) assert response.status_code == 200 expected_response = self._get_expected_response() - expected_response['divided_course_wide_discussions'] = [ - topic_name_to_id(self.course, name) for name in divided_course_wide_discussions + expected_response["divided_course_wide_discussions"] = [ + topic_name_to_id(self.course, name) + for name in divided_course_wide_discussions ] - expected_response['divided_inline_discussions'] = [ + expected_response["divided_inline_discussions"] = [ topic_name_to_id(self.course, name) for name in divided_inline_discussions ] - content = json.loads(response.content.decode('utf-8')) + content = json.loads(response.content.decode("utf-8")) assert content == expected_response def test_available_schemes(self): @@ -1569,18 +1759,23 @@ def test_available_schemes(self): config_course_cohorts(self.course, is_cohorted=False) self._login_as_staff() expected_response = self._get_expected_response() - expected_response['available_division_schemes'] = [] + expected_response["available_division_schemes"] = [] self._assert_current_settings(expected_response) CourseModeFactory.create(course_id=self.course.id, mode_slug=CourseMode.AUDIT) - CourseModeFactory.create(course_id=self.course.id, mode_slug=CourseMode.VERIFIED) + CourseModeFactory.create( + course_id=self.course.id, mode_slug=CourseMode.VERIFIED + ) - expected_response['available_division_schemes'] = [CourseDiscussionSettings.ENROLLMENT_TRACK] + expected_response["available_division_schemes"] = [ + CourseDiscussionSettings.ENROLLMENT_TRACK + ] self._assert_current_settings(expected_response) config_course_cohorts(self.course, is_cohorted=True) - expected_response['available_division_schemes'] = [ - CourseDiscussionSettings.COHORT, CourseDiscussionSettings.ENROLLMENT_TRACK + expected_response["available_division_schemes"] = [ + CourseDiscussionSettings.COHORT, + CourseDiscussionSettings.ENROLLMENT_TRACK, ] self._assert_current_settings(expected_response) @@ -1594,11 +1789,11 @@ def test_empty_body_patch_request(self): assert response.status_code == 400 @ddt.data( - {'abc': 123}, - {'divided_course_wide_discussions': 3}, - {'divided_inline_discussions': 'a'}, - {'always_divide_inline_discussions': ['a']}, - {'division_scheme': True} + {"abc": 123}, + {"divided_course_wide_discussions": 3}, + {"divided_inline_discussions": "a"}, + {"always_divide_inline_discussions": ["a"]}, + {"division_scheme": True}, ) def test_invalid_body_parameters(self, body): """Test the response status code on sending a PATCH request with parameters having incorrect types.""" @@ -1612,31 +1807,34 @@ def test_update_always_divide_inline_discussion_settings(self): self._login_as_staff() expected_response = self._get_expected_response() self._assert_current_settings(expected_response) - expected_response['always_divide_inline_discussions'] = True + expected_response["always_divide_inline_discussions"] = True - self._assert_patched_settings({'always_divide_inline_discussions': True}, expected_response) + self._assert_patched_settings( + {"always_divide_inline_discussions": True}, expected_response + ) def test_update_course_wide_discussion_settings(self): """Test whether the 'divided_course_wide_discussions' setting is updated.""" - discussion_topics = { - 'Topic B': {'id': 'Topic B'} - } + discussion_topics = {"Topic B": {"id": "Topic B"}} config_course_cohorts(self.course, is_cohorted=True) config_course_discussions(self.course, discussion_topics=discussion_topics) expected_response = self._get_expected_response() self._login_as_staff() self._assert_current_settings(expected_response) - expected_response['divided_course_wide_discussions'] = [ + expected_response["divided_course_wide_discussions"] = [ topic_name_to_id(self.course, "Topic B") ] self._assert_patched_settings( - {'divided_course_wide_discussions': [topic_name_to_id(self.course, "Topic B")]}, - expected_response + { + "divided_course_wide_discussions": [ + topic_name_to_id(self.course, "Topic B") + ] + }, + expected_response, ) - expected_response['divided_course_wide_discussions'] = [] + expected_response["divided_course_wide_discussions"] = [] self._assert_patched_settings( - {'divided_course_wide_discussions': []}, - expected_response + {"divided_course_wide_discussions": []}, expected_response ) def test_update_inline_discussion_settings(self): @@ -1649,17 +1847,23 @@ def test_update_inline_discussion_settings(self): now = datetime.now() BlockFactory.create( parent_location=self.course.location, - category='discussion', - discussion_id='Topic_A', - discussion_category='Chapter', - discussion_target='Discussion', - start=now + category="discussion", + discussion_id="Topic_A", + discussion_category="Chapter", + discussion_target="Discussion", + start=now, + ) + expected_response["divided_inline_discussions"] = [ + "Topic_A", + ] + self._assert_patched_settings( + {"divided_inline_discussions": ["Topic_A"]}, expected_response ) - expected_response['divided_inline_discussions'] = ['Topic_A', ] - self._assert_patched_settings({'divided_inline_discussions': ['Topic_A']}, expected_response) - expected_response['divided_inline_discussions'] = [] - self._assert_patched_settings({'divided_inline_discussions': []}, expected_response) + expected_response["divided_inline_discussions"] = [] + self._assert_patched_settings( + {"divided_inline_discussions": []}, expected_response + ) def test_update_division_scheme(self): """Test whether the 'division_scheme' setting is updated.""" @@ -1667,15 +1871,17 @@ def test_update_division_scheme(self): self._login_as_staff() expected_response = self._get_expected_response() self._assert_current_settings(expected_response) - expected_response['division_scheme'] = 'none' - self._assert_patched_settings({'division_scheme': 'none'}, expected_response) + expected_response["division_scheme"] = "none" + self._assert_patched_settings({"division_scheme": "none"}, expected_response) def test_update_reported_content_email_notifications(self): """Test whether the 'reported_content_email_notifications' setting is updated.""" config_course_cohorts(self.course, is_cohorted=True) - config_course_discussions(self.course, reported_content_email_notifications=True) + config_course_discussions( + self.course, reported_content_email_notifications=True + ) expected_response = self._get_expected_response() - expected_response['reported_content_email_notifications'] = True + expected_response["reported_content_email_notifications"] = True self._login_as_staff() self._assert_current_settings(expected_response) @@ -1685,12 +1891,15 @@ class CourseDiscussionRolesAPIViewTest(APITestCase, UrlResetMixin, ModuleStoreTe """ Test the course discussion roles management endpoint. """ - @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) + + @mock.patch.dict( + "django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True} + ) def setUp(self): super().setUp() patcher = mock.patch( - 'openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled', - return_value=False + "openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled", + return_value=False, ) patcher.start() self.addCleanup(patcher.stop) @@ -1701,26 +1910,27 @@ def setUp(self): start=datetime.now(UTC), ) self.password = self.TEST_PASSWORD - self.user = UserFactory(username='staff', password=self.password, is_staff=True) - course_key = CourseKey.from_string('course-v1:x+y+z') + self.user = UserFactory(username="staff", password=self.password, is_staff=True) + course_key = CourseKey.from_string("course-v1:x+y+z") seed_permissions_roles(course_key) - @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) + @mock.patch.dict( + "django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True} + ) def path(self, course_id=None, role=None): """Return the URL path to the endpoint based on the provided arguments.""" course_id = str(self.course.id) if course_id is None else course_id - role = 'Moderator' if role is None else role + role = "Moderator" if role is None else role return reverse( - 'discussion_course_roles', - kwargs={'course_id': course_id, 'rolename': role} + "discussion_course_roles", kwargs={"course_id": course_id, "rolename": role} ) def _get_oauth_headers(self, user): """Return the OAuth headers for testing OAuth authentication.""" - access_token = AccessTokenFactory.create(user=user, application=ApplicationFactory()).token - headers = { - 'HTTP_AUTHORIZATION': 'Bearer ' + access_token - } + access_token = AccessTokenFactory.create( + user=user, application=ApplicationFactory() + ).token + headers = {"HTTP_AUTHORIZATION": "Bearer " + access_token} return headers def _login_as_staff(self): @@ -1745,9 +1955,11 @@ def _add_users_to_role(self, users, rolename): def post(self, role, user_id, action): """Make a POST request to the endpoint using the provided parameters.""" self._login_as_staff() - return self.client.post(self.path(role=role), {'user_id': user_id, 'action': action}) + return self.client.post( + self.path(role=role), {"user_id": user_id, "action": action} + ) - @ddt.data('get', 'post') + @ddt.data("get", "post") def test_authentication_required(self, method): """Test and verify that authentication is required for this endpoint.""" self.client.logout() @@ -1760,29 +1972,31 @@ def test_oauth(self): self.client.logout() response = self.client.get(self.path(), **oauth_headers) assert response.status_code == 200 - body = {'user_id': 'staff', 'action': 'allow'} - response = self.client.post(self.path(), body, format='json', **oauth_headers) + body = {"user_id": "staff", "action": "allow"} + response = self.client.post(self.path(), body, format="json", **oauth_headers) assert response.status_code == 200 @ddt.data( - {'username': 'u1', 'is_staff': False, 'expected_status': 403}, - {'username': 'u2', 'is_staff': True, 'expected_status': 200}, + {"username": "u1", "is_staff": False, "expected_status": 403}, + {"username": "u2", "is_staff": True, "expected_status": 200}, ) @ddt.unpack def test_staff_permission_required(self, username, is_staff, expected_status): """Test and verify that only users with staff permission can access this endpoint.""" - UserFactory(username=username, password='edx', is_staff=is_staff) - self.client.login(username=username, password='edx') + UserFactory(username=username, password="edx", is_staff=is_staff) + self.client.login(username=username, password="edx") response = self.client.get(self.path()) assert response.status_code == expected_status - response = self.client.post(self.path(), {'user_id': username, 'action': 'allow'}, format='json') + response = self.client.post( + self.path(), {"user_id": username, "action": "allow"}, format="json" + ) assert response.status_code == expected_status def test_non_existent_course_id(self): """Test the response when the endpoint URL contains a non-existent course id.""" self._login_as_staff() - path = self.path(course_id='course-v1:a+b+c') + path = self.path(course_id="course-v1:a+b+c") response = self.client.get(path) assert response.status_code == 404 @@ -1793,7 +2007,7 @@ def test_non_existent_course_id(self): def test_non_existent_course_role(self): """Test the response when the endpoint URL contains a non-existent role.""" self._login_as_staff() - path = self.path(role='A') + path = self.path(role="A") response = self.client.get(path) assert response.status_code == 400 @@ -1802,10 +2016,10 @@ def test_non_existent_course_role(self): assert response.status_code == 400 @ddt.data( - {'role': 'Moderator', 'count': 0}, - {'role': 'Moderator', 'count': 1}, - {'role': 'Group Moderator', 'count': 2}, - {'role': 'Community TA', 'count': 3}, + {"role": "Moderator", "count": 0}, + {"role": "Moderator", "count": 1}, + {"role": "Group Moderator", "count": 2}, + {"role": "Community TA", "count": 3}, ) @ddt.unpack def test_get_role_members(self, role, count): @@ -1819,14 +2033,14 @@ def test_get_role_members(self, role, count): assert response.status_code == 200 - content = json.loads(response.content.decode('utf-8')) - assert content['course_id'] == 'course-v1:x+y+z' - assert len(content['results']) == count - expected_fields = ('username', 'email', 'first_name', 'last_name', 'group_name') - for item in content['results']: + content = json.loads(response.content.decode("utf-8")) + assert content["course_id"] == "course-v1:x+y+z" + assert len(content["results"]) == count + expected_fields = ("username", "email", "first_name", "last_name", "group_name") + for item in content["results"]: for expected_field in expected_fields: assert expected_field in item - assert content['division_scheme'] == 'cohort' + assert content["division_scheme"] == "cohort" def test_post_missing_body(self): """Test the response with a POST request without a body.""" @@ -1835,9 +2049,9 @@ def test_post_missing_body(self): assert response.status_code == 400 @ddt.data( - {'a': 1}, - {'user_id': 'xyz', 'action': 'allow'}, - {'user_id': 'staff', 'action': 123}, + {"a": 1}, + {"user_id": "xyz", "action": "allow"}, + {"user_id": "staff", "action": 123}, ) def test_missing_or_invalid_parameters(self, body): """ @@ -1848,82 +2062,100 @@ def test_missing_or_invalid_parameters(self, body): response = self.client.post(self.path(), body) assert response.status_code == 400 - response = self.client.post(self.path(), body, format='json') + response = self.client.post(self.path(), body, format="json") assert response.status_code == 400 @ddt.data( - {'action': 'allow', 'user_in_role': False}, - {'action': 'allow', 'user_in_role': True}, - {'action': 'revoke', 'user_in_role': False}, - {'action': 'revoke', 'user_in_role': True} + {"action": "allow", "user_in_role": False}, + {"action": "allow", "user_in_role": True}, + {"action": "revoke", "user_in_role": False}, + {"action": "revoke", "user_in_role": True}, ) @ddt.unpack def test_post_update_user_role(self, action, user_in_role): """Test the response when updating the user's role""" users = self._create_and_enroll_users(count=1) user = users[0] - role = 'Moderator' + role = "Moderator" if user_in_role: self._add_users_to_role(users, role) response = self.post(role, user.username, action) assert response.status_code == 200 - content = json.loads(response.content.decode('utf-8')) - assertion = self.assertTrue if action == 'allow' else self.assertFalse - assertion(any(user.username in x['username'] for x in content['results'])) + content = json.loads(response.content.decode("utf-8")) + assertion = self.assertTrue if action == "allow" else self.assertFalse + assertion(any(user.username in x["username"] for x in content["results"])) @ddt.ddt @httpretty.activate @override_waffle_flag(ENABLE_DISCUSSIONS_MFE, True) -class CourseActivityStatsTest(ForumsEnableMixin, UrlResetMixin, CommentsServiceMockMixin, APITestCase, - SharedModuleStoreTestCase): +class CourseActivityStatsTest( + ForumsEnableMixin, + UrlResetMixin, + CommentsServiceMockMixin, + APITestCase, + SharedModuleStoreTestCase, +): """ Tests for the course stats endpoint """ - @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) + @mock.patch.dict( + "django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True} + ) def setUp(self) -> None: super().setUp() patcher = mock.patch( - 'openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled', - return_value=False + "openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled", + return_value=False, ) patcher.start() self.addCleanup(patcher.stop) self.course = CourseFactory.create() self.course_key = str(self.course.id) seed_permissions_roles(self.course.id) - self.user = UserFactory(username='user') - self.moderator = UserFactory(username='moderator') + self.user = UserFactory(username="user") + self.moderator = UserFactory(username="moderator") moderator_role = Role.objects.get(name="Moderator", course_id=self.course.id) moderator_role.users.add(self.moderator) self.stats = [ { - "active_flags": random.randint(0, 3), - "inactive_flags": random.randint(0, 2), + "threads": random.randint(0, 10), "replies": random.randint(0, 30), "responses": random.randint(0, 100), - "threads": random.randint(0, 10), - "username": f"user-{idx}" + "deleted_threads": 0, + "deleted_replies": 0, + "deleted_responses": 0, + "active_flags": random.randint(0, 3), + "inactive_flags": random.randint(0, 2), + "username": f"user-{idx}", } for idx in range(10) ] for stat in self.stats: user = UserFactory.create( - username=stat['username'], + username=stat["username"], email=f"{stat['username']}@example.com", - password=self.TEST_PASSWORD + password=self.TEST_PASSWORD, ) - CourseEnrollment.enroll(user, self.course.id, mode='audit') + CourseEnrollment.enroll(user, self.course.id, mode="audit") - CourseEnrollment.enroll(self.moderator, self.course.id, mode='audit') - self.stats_without_flags = [{**stat, "active_flags": None, "inactive_flags": None} for stat in self.stats] + CourseEnrollment.enroll(self.moderator, self.course.id, mode="audit") + self.stats_without_flags = [ + {**stat, "active_flags": None, "inactive_flags": None} + for stat in self.stats + ] self.register_course_stats_response(self.course_key, self.stats, 1, 3) - self.url = reverse("discussion_course_activity_stats", kwargs={"course_key_string": self.course_key}) + self.url = reverse( + "discussion_course_activity_stats", + kwargs={"course_key_string": self.course_key}, + ) - @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) + @mock.patch.dict( + "django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True} + ) def test_regular_user(self): """ Tests that for a regular user stats are returned without flag counts @@ -1933,7 +2165,9 @@ def test_regular_user(self): data = response.json() assert data["results"] == self.stats_without_flags - @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) + @mock.patch.dict( + "django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True} + ) def test_moderator_user(self): """ Tests that for a moderator user stats are returned with flag counts @@ -1953,25 +2187,40 @@ def test_moderator_user(self): ("user", "recency", "recency"), ) @ddt.unpack - @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) - def test_sorting(self, username, ordering_requested, ordering_performed): + @mock.patch.dict( + "django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True} + ) + @mock.patch("lms.djangoapps.discussion.rest_api.api.get_course_user_stats") + def test_sorting( + self, + username, + ordering_requested, + ordering_performed, + mock_get_course_user_stats, + ): """ Test valid sorting options and defaults """ + mock_get_course_user_stats.return_value = { + "user_stats": [], + "page": 1, + "num_pages": 1, + "count": 0, + } self.client.login(username=username, password=self.TEST_PASSWORD) params = {} if ordering_requested: params = {"order_by": ordering_requested} self.client.get(self.url, params) - assert urlparse( - httpretty.last_request().path # lint-amnesty, pylint: disable=no-member - ).path == f"/api/v1/users/{self.course_key}/stats" - assert parse_qs( - urlparse(httpretty.last_request().path).query # lint-amnesty, pylint: disable=no-member - ).get("sort_key", None) == [ordering_performed] + + call_args, call_kwargs = mock_get_course_user_stats.call_args + called_params = call_kwargs.get("params") or call_args[1] + assert called_params.get("sort_key") == ordering_performed @ddt.data("flagged", "xyz") - @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) + @mock.patch.dict( + "django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True} + ) def test_sorting_error_regular_user(self, order_by): """ Test for invalid sorting options for regular users. @@ -1981,47 +2230,225 @@ def test_sorting_error_regular_user(self, order_by): assert "order_by" in response.json()["field_errors"] @ddt.data( - ('user', 'user-0,user-1,user-2,user-3,user-4,user-5,user-6,user-7,user-8,user-9'), - ('moderator', 'moderator'), + ( + "user", + "user-0,user-1,user-2,user-3,user-4,user-5,user-6,user-7,user-8,user-9", + ), + ("moderator", "moderator"), ) @ddt.unpack - @mock.patch.dict("django.conf.settings.FEATURES", {'ENABLE_DISCUSSION_SERVICE': True}) - def test_with_username_param(self, username_search_string, comma_separated_usernames): + @mock.patch.dict( + "django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True} + ) + def test_with_username_param( + self, username_search_string, comma_separated_usernames + ): """ Test for endpoint with username param. """ - params = {'username': username_search_string} + params = {"username": username_search_string} self.client.login(username=self.moderator.username, password=self.TEST_PASSWORD) self.client.get(self.url, params) - assert urlparse( - httpretty.last_request().path # lint-amnesty, pylint: disable=no-member - ).path == f'/api/v1/users/{self.course_key}/stats' + assert ( + urlparse( + httpretty.last_request().path # lint-amnesty, pylint: disable=no-member + ).path + == f"/api/v1/users/{self.course_key}/stats" + ) assert parse_qs( - urlparse(httpretty.last_request().path).query # lint-amnesty, pylint: disable=no-member - ).get('usernames', [None]) == [comma_separated_usernames] + urlparse( + httpretty.last_request().path + ).query # lint-amnesty, pylint: disable=no-member + ).get("usernames", [None]) == [comma_separated_usernames] - @mock.patch.dict("django.conf.settings.FEATURES", {'ENABLE_DISCUSSION_SERVICE': True}) + @mock.patch.dict( + "django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True} + ) def test_with_username_param_with_no_matches(self): """ Test for endpoint with username param with no matches. """ - params = {'username': 'unknown'} + params = {"username": "unknown"} self.client.login(username=self.moderator.username, password=self.TEST_PASSWORD) response = self.client.get(self.url, params) data = response.json() - self.assertFalse(data['results']) - assert data['pagination']['count'] == 0 + self.assertFalse(data["results"]) + assert data["pagination"]["count"] == 0 - @ddt.data( - 'user-0', - 'USER-1', - 'User-2', - 'UsEr-3' + @ddt.data("user-0", "USER-1", "User-2", "UsEr-3") + @mock.patch.dict( + "django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True} ) - @mock.patch.dict("django.conf.settings.FEATURES", {'ENABLE_DISCUSSION_SERVICE': True}) def test_with_username_param_case(self, username_search_string): """ Test user search function is case-insensitive. """ - response = get_usernames_from_search_string(self.course_key, username_search_string, 1, 1) + response = get_usernames_from_search_string( + self.course_key, username_search_string, 1, 1 + ) assert response == (username_search_string.lower(), 1, 1) + + @mock.patch.dict( + "django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True} + ) + def test_banned_username_lookup_error_fails_open(self): + """Stats endpoint should not fail when banned-username lookup backend errors.""" + self.client.login(username=self.user.username, password=self.TEST_PASSWORD) + with mock.patch( + "lms.djangoapps.discussion.rest_api.api.ENABLE_DISCUSSION_BAN.is_enabled", + return_value=True, + ), mock.patch( + "lms.djangoapps.discussion.rest_api.api.forum_api.get_banned_usernames", + side_effect=CommentClientRequestError("temporary backend failure"), + create=True, + ): + response = self.client.get(self.url) + + assert response.status_code == 200 + data = response.json() + assert data["results"] == self.stats_without_flags + + +@mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) +@ddt.ddt +class DiscussionModerationViewSetUnitTests(APITestCase): + """Unit tests for DiscussionModerationViewSet helper behavior.""" + + class _DiscussionModerationViewSetTestProxy(DiscussionModerationViewSet): + """Test proxy exposing internal helper for unit testing.""" + + def get_or_create_ban_proxy(self, user, course_key, ban_scope, reason, request): + return self._get_or_create_ban(user, course_key, ban_scope, reason, request) + + def validate_ban_request_proxy(self, request, serializer_data): + return self._validate_ban_request_and_get_user(request, serializer_data) + + def setUp(self): + super().setUp() + self.viewset = self._DiscussionModerationViewSetTestProxy() + self.user = UserFactory.create() + self.moderator = UserFactory.create() + self.request = mock.Mock(user=self.moderator) + self.course_key = CourseKey.from_string("course-v1:x+y+z") + + @ddt.data(("course", False), ("organization", True)) + @ddt.unpack + def test_get_or_create_ban_uses_expected_check_org(self, ban_scope, check_org): + with mock.patch("forum.api.is_user_banned", return_value=False, create=True) as is_user_banned, mock.patch( + "forum.api.ban_user", + return_value={"id": 1, "reactivated": False}, + create=True, + ): + self.viewset.get_or_create_ban_proxy( + user=self.user, + course_key=self.course_key, + ban_scope=ban_scope, + reason="", + request=self.request, + ) + + is_user_banned.assert_called_once_with( + self.user, + self.course_key, + check_org=check_org, + ) + + def test_validate_ban_request_invalid_course_id_returns_400(self): + result = self.viewset.validate_ban_request_proxy( + request=self.request, + serializer_data={ + "user_id": self.user.id, + "course_id": "invalid-course-id", + "scope": "course", + "reason": "", + }, + ) + + assert result.status_code == status.HTTP_400_BAD_REQUEST + assert result.data == {"error": "Invalid course_id: invalid-course-id"} + + def test_bulk_delete_ban_invalid_course_id_returns_400(self): + request = mock.Mock(user=self.moderator, data={}) + serializer_instance = mock.Mock() + serializer_instance.is_valid.return_value = True + serializer_instance.validated_data = { + "user_id": self.user.id, + "course_id": "invalid-course-id", + "ban_user": False, + "ban_scope": "course", + "reason": "", + } + + with mock.patch( + "lms.djangoapps.discussion.rest_api.serializers.BulkDeleteBanRequestSerializer", + return_value=serializer_instance, + ): + response = self.viewset.bulk_delete_ban(request) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.data == {"error": "Invalid course_id: invalid-course-id"} + + def test_banned_users_invalid_course_id_returns_400(self): + request = mock.Mock(user=self.moderator, query_params={}) + + response = self.viewset.banned_users(request, course_id="invalid-course-id") + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.data == {"error": "Invalid course_id: invalid-course-id"} + + def test_unban_user_by_id_invalid_course_id_returns_400(self): + request = mock.Mock( + user=self.moderator, + data={"course_id": "invalid-course-id", "reason": "appeal approved"}, + ) + + with mock.patch( + "forum.api.get_ban", + return_value={"is_active": True, "course_id": None, "scope": "organization", "org_key": "x"}, + create=True, + ): + response = self.viewset.unban_user_by_id(request, pk=1) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.data == {"error": "Invalid course_id: invalid-course-id"} + + +@mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) +class BulkDeleteBanRequestSerializerUnitTests(APITestCase): + """Unit tests for BulkDeleteBanRequestSerializer validation behavior.""" + + def setUp(self): + super().setUp() + self.target_user = UserFactory.create() + self.course_id = "course-v1:x+y+z" + + def test_org_scope_accepts_is_staff_when_ban_user_true(self): + acting_user = UserFactory.create(is_staff=True) + request = mock.Mock(user=acting_user) + serializer = BulkDeleteBanRequestSerializer( + data={ + "user_id": self.target_user.id, + "course_id": self.course_id, + "ban_user": True, + "ban_scope": "organization", + "reason": "policy violation", + }, + context={"request": request}, + ) + + assert serializer.is_valid(), serializer.errors + + def test_org_scope_skips_permission_check_when_ban_user_false(self): + acting_user = UserFactory.create(is_staff=False) + request = mock.Mock(user=acting_user) + serializer = BulkDeleteBanRequestSerializer( + data={ + "user_id": self.target_user.id, + "course_id": self.course_id, + "ban_user": False, + "ban_scope": "organization", + }, + context={"request": request}, + ) + + assert serializer.is_valid(), serializer.errors diff --git a/lms/djangoapps/discussion/rest_api/tests/test_views_v2.py b/lms/djangoapps/discussion/rest_api/tests/test_views_v2.py index 4247cbcab06c..c2f2f1877a79 100644 --- a/lms/djangoapps/discussion/rest_api/tests/test_views_v2.py +++ b/lms/djangoapps/discussion/rest_api/tests/test_views_v2.py @@ -14,8 +14,6 @@ from unittest import mock import ddt -from forum.backends.mongodb.comments import Comment -from forum.backends.mongodb.threads import CommentThread import httpretty from django.urls import reverse from pytz import UTC @@ -23,30 +21,39 @@ from rest_framework.parsers import JSONParser from rest_framework.test import APIClient -from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase -from xmodule.modulestore.tests.factories import CourseFactory from common.djangoapps.student.tests.factories import ( CourseEnrollmentFactory, UserFactory, ) from common.djangoapps.util.testing import PatchMediaTypeMixin, UrlResetMixin from common.test.utils import disable_signal -from lms.djangoapps.discussion.tests.utils import ( - make_minimal_cs_comment, - make_minimal_cs_thread, +from forum.backends.mongodb.comments import Comment +from forum.backends.mongodb.threads import CommentThread +from lms.djangoapps.discussion.django_comment_client.tests.utils import ( + ForumsEnableMixin, ) -from lms.djangoapps.discussion.django_comment_client.tests.utils import ForumsEnableMixin from lms.djangoapps.discussion.rest_api import api from lms.djangoapps.discussion.rest_api.tests.utils import ( ForumMockUtilsMixin, ProfileImageTestMixin, make_paginated_api_response, ) +from lms.djangoapps.discussion.tests.utils import ( + make_minimal_cs_comment, + make_minimal_cs_thread, +) from openedx.core.djangoapps.django_comment_common.models import ( - FORUM_ROLE_ADMINISTRATOR, FORUM_ROLE_COMMUNITY_TA, FORUM_ROLE_MODERATOR, FORUM_ROLE_STUDENT, - assign_role + FORUM_ROLE_ADMINISTRATOR, + FORUM_ROLE_COMMUNITY_TA, + FORUM_ROLE_MODERATOR, + FORUM_ROLE_STUDENT, + assign_role, ) -from openedx.core.djangoapps.user_api.accounts.image_helpers import get_profile_image_storage +from openedx.core.djangoapps.user_api.accounts.image_helpers import ( + get_profile_image_storage, +) +from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase +from xmodule.modulestore.tests.factories import CourseFactory class DiscussionAPIViewTestMixin(ForumsEnableMixin, ForumMockUtilsMixin, UrlResetMixin): @@ -112,6 +119,7 @@ def register_thread(self, overrides=None): "thread_type": "discussion", "title": "Test Title", "body": "Test body", + "is_deleted": False, } ) cs_thread.update(overrides or {}) @@ -330,6 +338,7 @@ def test_patch_read_non_owner_user(self): "voted", ], "response_count": 2, + "is_deleted": None, } ) assert response_data == expected_data @@ -359,6 +368,8 @@ def expected_response_data(self, overrides=None): "parent_id": None, "author": self.user.username, "author_label": None, + "is_author_banned": False, + "author_ban_scope": None, "created_at": "1970-01-01T00:00:00Z", "updated_at": "1970-01-01T00:00:00Z", "raw_body": "Original body", @@ -386,6 +397,11 @@ def expected_response_data(self, overrides=None): "image_url_medium": "http://testserver/static/default_50.png", "image_url_small": "http://testserver/static/default_30.png", }, + "learner_status": "new", + "is_deleted": False, + "deleted_at": None, + "deleted_by": None, + "deleted_by_label": None, } response_data.update(overrides or {}) return response_data @@ -500,6 +516,7 @@ def create_source_thread(self, overrides=None): "votes": {"up_count": 4}, "comments_count": 5, "unread_comments_count": 3, + "is_deleted": False, } ) @@ -511,15 +528,17 @@ def test_course_id_missing(self): self.assert_response_correct( response, 400, - {"field_errors": {"course_id": {"developer_message": "This field is required."}}} + { + "field_errors": { + "course_id": {"developer_message": "This field is required."} + } + }, ) def test_404(self): response = self.client.get(self.url, {"course_id": "non/existent/course"}) self.assert_response_correct( - response, - 404, - {"developer_message": "Course not found."} + response, 404, {"developer_message": "Course not found."} ) def test_basic(self): @@ -548,6 +567,7 @@ def test_basic(self): "voted", ], "abuse_flagged_count": None, + "is_deleted": None, } ) ] @@ -870,7 +890,9 @@ class BulkDeleteUserPostsTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): def setUp(self): super().setUp() - self.url = reverse("bulk_delete_user_posts", kwargs={"course_id": str(self.course.id)}) + self.url = reverse( + "bulk_delete_user_posts", kwargs={"course_id": str(self.course.id)} + ) self.user2 = UserFactory.create(password=self.password) CourseEnrollmentFactory.create(user=self.user2, course_id=self.course.id) @@ -886,13 +908,19 @@ def mock_comment_and_thread_count(self, comment_count=1, thread_count=1): thread_collection = mock.MagicMock() thread_collection.count_documents.return_value = thread_count patch_thread = mock.patch.object( - CommentThread, "_collection", new_callable=mock.PropertyMock, return_value=thread_collection + CommentThread, + "_collection", + new_callable=mock.PropertyMock, + return_value=thread_collection, ) comment_collection = mock.MagicMock() comment_collection.count_documents.return_value = comment_count patch_comment = mock.patch.object( - Comment, "_collection", new_callable=mock.PropertyMock, return_value=comment_collection + Comment, + "_collection", + new_callable=mock.PropertyMock, + return_value=comment_collection, ) thread_mock = patch_thread.start() @@ -907,7 +935,9 @@ def test_bulk_delete_denied_for_discussion_roles(self, role): """ Test bulk delete user posts denied with discussion roles. """ - thread_mock, comment_mock = self.mock_comment_and_thread_count(comment_count=1, thread_count=1) + thread_mock, comment_mock = self.mock_comment_and_thread_count( + comment_count=1, thread_count=1 + ) assign_role(self.course.id, self.user, role) response = self.client.post( f"{self.url}?username={self.user2.username}", @@ -931,7 +961,9 @@ def test_bulk_delete_allowed_for_discussion_roles(self, role): assert response.status_code == status.HTTP_202_ACCEPTED assert response.json() == {"comment_count": 1, "thread_count": 1} - @mock.patch('lms.djangoapps.discussion.rest_api.views.delete_course_post_for_user.apply_async') + @mock.patch( + "lms.djangoapps.discussion.rest_api.views.delete_course_post_for_user.apply_async" + ) @ddt.data(True, False) def test_task_only_runs_if_execute_param_is_true(self, execute, task_mock): """ diff --git a/lms/djangoapps/discussion/rest_api/tests/utils.py b/lms/djangoapps/discussion/rest_api/tests/utils.py index 342afb0ada5e..5d58ad9c3fac 100644 --- a/lms/djangoapps/discussion/rest_api/tests/utils.py +++ b/lms/djangoapps/discussion/rest_api/tests/utils.py @@ -2,7 +2,6 @@ Discussion API test utilities """ - import hashlib import json import re @@ -14,11 +13,18 @@ from PIL import Image from pytz import UTC -from lms.djangoapps.discussion.django_comment_client.tests.mixins import MockForumApiMixin -from openedx.core.djangoapps.django_comment_common.comment_client.utils import CommentClientRequestError +from lms.djangoapps.discussion.django_comment_client.tests.mixins import ( + MockForumApiMixin, +) +from openedx.core.djangoapps.django_comment_common.comment_client.utils import ( + CommentClientRequestError, +) from openedx.core.djangoapps.profile_images.images import create_profile_images from openedx.core.djangoapps.profile_images.tests.helpers import make_image_file -from openedx.core.djangoapps.user_api.accounts.image_helpers import get_profile_image_names, set_has_profile_image +from openedx.core.djangoapps.user_api.accounts.image_helpers import ( + get_profile_image_names, + set_has_profile_image, +) def _get_thread_callback(thread_data): @@ -26,6 +32,7 @@ def _get_thread_callback(thread_data): Get a callback function that will return POST/PUT data overridden by response_overrides. """ + def callback(request, _uri, headers): """ Simulate the thread creation or update endpoint by returning the provided @@ -42,7 +49,7 @@ def callback(request, _uri, headers): response_data["edit_history"] = [ { "original_body": original_data["body"], - "author": thread_data.get('username'), + "author": thread_data.get("username"), "reason_code": val, }, ] @@ -68,11 +75,13 @@ def callback(*args, **kwargs): if key in ["anonymous", "anonymous_to_peers", "closed", "pinned"]: response_data[key] = val is True or val == "True" elif key == "edit_reason_code": - response_data["edit_history"] = [{ - "original_body": original_data["body"], - "author": thread_data.get("username"), - "reason_code": val, - }] + response_data["edit_history"] = [ + { + "original_body": original_data["body"], + "author": thread_data.get("username"), + "reason_code": val, + } + ] else: response_data[key] = val @@ -87,6 +96,7 @@ def _get_comment_callback(comment_data, thread_id, parent_id): plus necessary dummy data, overridden by the content of the POST/PUT request. """ + def callback(request, _uri, headers): """ Simulate the comment creation or update endpoint as described above. @@ -105,7 +115,7 @@ def callback(request, _uri, headers): response_data["edit_history"] = [ { "original_body": original_data["body"], - "author": comment_data.get('username'), + "author": comment_data.get("username"), "reason_code": val, }, ] @@ -135,11 +145,13 @@ def callback(*args, **kwargs): if key in ["anonymous", "anonymous_to_peers", "endorsed"]: response_data[key] = val is True or val == "True" elif key == "edit_reason_code": - response_data["edit_history"] = [{ - "original_body": original_data["body"], - "author": comment_data.get("username"), - "reason_code": val, - }] + response_data["edit_history"] = [ + { + "original_body": original_data["body"], + "author": comment_data.get("username"), + "reason_code": val, + } + ] else: response_data[key] = val @@ -152,9 +164,11 @@ def make_user_callbacks(user_map): """ Returns a callable that mimics user creation. """ + def callback(*args, **kwargs): - user_id = args[0] if args else kwargs.get('user_id') + user_id = args[0] if args else kwargs.get("user_id") return user_map[str(user_id)] + return callback @@ -163,54 +177,58 @@ class CommentsServiceMockMixin: def register_get_threads_response(self, threads, page, num_pages): """Register a mock response for GET on the CS thread list endpoint""" - assert httpretty.is_enabled(), 'httpretty must be enabled to mock calls.' + assert httpretty.is_enabled(), "httpretty must be enabled to mock calls." httpretty.register_uri( httpretty.GET, "http://localhost:4567/api/v1/threads", - body=json.dumps({ - "collection": threads, - "page": page, - "num_pages": num_pages, - "thread_count": len(threads), - }), - status=200 + body=json.dumps( + { + "collection": threads, + "page": page, + "num_pages": num_pages, + "thread_count": len(threads), + } + ), + status=200, ) def register_get_course_commentable_counts_response(self, course_id, thread_counts): """Register a mock response for GET on the CS thread list endpoint""" - assert httpretty.is_enabled(), 'httpretty must be enabled to mock calls.' + assert httpretty.is_enabled(), "httpretty must be enabled to mock calls." httpretty.register_uri( httpretty.GET, f"http://localhost:4567/api/v1/commentables/{course_id}/counts", body=json.dumps(thread_counts), - status=200 + status=200, ) def register_get_threads_search_response(self, threads, rewrite, num_pages=1): """Register a mock response for GET on the CS thread search endpoint""" - assert httpretty.is_enabled(), 'httpretty must be enabled to mock calls.' + assert httpretty.is_enabled(), "httpretty must be enabled to mock calls." httpretty.register_uri( httpretty.GET, "http://localhost:4567/api/v1/search/threads", - body=json.dumps({ - "collection": threads, - "page": 1, - "num_pages": num_pages, - "corrected_text": rewrite, - "thread_count": len(threads), - }), - status=200 + body=json.dumps( + { + "collection": threads, + "page": 1, + "num_pages": num_pages, + "corrected_text": rewrite, + "thread_count": len(threads), + } + ), + status=200, ) def register_post_thread_response(self, thread_data): """Register a mock response for POST on the CS commentable endpoint""" - assert httpretty.is_enabled(), 'httpretty must be enabled to mock calls.' + assert httpretty.is_enabled(), "httpretty must be enabled to mock calls." httpretty.register_uri( httpretty.POST, re.compile(r"http://localhost:4567/api/v1/(\w+)/threads"), - body=_get_thread_callback(thread_data) + body=_get_thread_callback(thread_data), ) def register_put_thread_response(self, thread_data): @@ -218,49 +236,51 @@ def register_put_thread_response(self, thread_data): Register a mock response for PUT on the CS endpoint for the given thread_id. """ - assert httpretty.is_enabled(), 'httpretty must be enabled to mock calls.' + assert httpretty.is_enabled(), "httpretty must be enabled to mock calls." httpretty.register_uri( httpretty.PUT, "http://localhost:4567/api/v1/threads/{}".format(thread_data["id"]), - body=_get_thread_callback(thread_data) + body=_get_thread_callback(thread_data), ) def register_get_thread_error_response(self, thread_id, status_code): """Register a mock error response for GET on the CS thread endpoint.""" - assert httpretty.is_enabled(), 'httpretty must be enabled to mock calls.' + assert httpretty.is_enabled(), "httpretty must be enabled to mock calls." httpretty.register_uri( httpretty.GET, f"http://localhost:4567/api/v1/threads/{thread_id}", body="", - status=status_code + status=status_code, ) def register_get_thread_response(self, thread): """ Register a mock response for GET on the CS thread instance endpoint. """ - assert httpretty.is_enabled(), 'httpretty must be enabled to mock calls.' + assert httpretty.is_enabled(), "httpretty must be enabled to mock calls." httpretty.register_uri( httpretty.GET, "http://localhost:4567/api/v1/threads/{id}".format(id=thread["id"]), body=json.dumps(thread), - status=200 + status=200, ) def register_get_comments_response(self, comments, page, num_pages): """Register a mock response for GET on the CS comments list endpoint""" - assert httpretty.is_enabled(), 'httpretty must be enabled to mock calls.' + assert httpretty.is_enabled(), "httpretty must be enabled to mock calls." httpretty.register_uri( httpretty.GET, "http://localhost:4567/api/v1/comments", - body=json.dumps({ - "collection": comments, - "page": page, - "num_pages": num_pages, - "comment_count": len(comments), - }), - status=200 + body=json.dumps( + { + "collection": comments, + "page": page, + "num_pages": num_pages, + "comment_count": len(comments), + } + ), + status=200, ) def register_post_comment_response(self, comment_data, thread_id, parent_id=None): @@ -274,11 +294,11 @@ def register_post_comment_response(self, comment_data, thread_id, parent_id=None else: url = f"http://localhost:4567/api/v1/threads/{thread_id}/comments" - assert httpretty.is_enabled(), 'httpretty must be enabled to mock calls.' + assert httpretty.is_enabled(), "httpretty must be enabled to mock calls." httpretty.register_uri( httpretty.POST, url, - body=_get_comment_callback(comment_data, thread_id, parent_id) + body=_get_comment_callback(comment_data, thread_id, parent_id), ) def register_put_comment_response(self, comment_data): @@ -288,11 +308,11 @@ def register_put_comment_response(self, comment_data): """ thread_id = comment_data["thread_id"] parent_id = comment_data.get("parent_id") - assert httpretty.is_enabled(), 'httpretty must be enabled to mock calls.' + assert httpretty.is_enabled(), "httpretty must be enabled to mock calls." httpretty.register_uri( httpretty.PUT, "http://localhost:4567/api/v1/comments/{}".format(comment_data["id"]), - body=_get_comment_callback(comment_data, thread_id, parent_id) + body=_get_comment_callback(comment_data, thread_id, parent_id), ) def register_get_comment_error_response(self, comment_id, status_code): @@ -300,12 +320,12 @@ def register_get_comment_error_response(self, comment_id, status_code): Register a mock error response for GET on the CS comment instance endpoint. """ - assert httpretty.is_enabled(), 'httpretty must be enabled to mock calls.' + assert httpretty.is_enabled(), "httpretty must be enabled to mock calls." httpretty.register_uri( httpretty.GET, f"http://localhost:4567/api/v1/comments/{comment_id}", body="", - status=status_code + status=status_code, ) def register_get_comment_response(self, response_overrides): @@ -313,75 +333,83 @@ def register_get_comment_response(self, response_overrides): Register a mock response for GET on the CS comment instance endpoint. """ comment = make_minimal_cs_comment(response_overrides) - assert httpretty.is_enabled(), 'httpretty must be enabled to mock calls.' + assert httpretty.is_enabled(), "httpretty must be enabled to mock calls." httpretty.register_uri( httpretty.GET, "http://localhost:4567/api/v1/comments/{id}".format(id=comment["id"]), body=json.dumps(comment), - status=200 + status=200, ) - def register_get_user_response(self, user, subscribed_thread_ids=None, upvoted_ids=None): + def register_get_user_response( + self, user, subscribed_thread_ids=None, upvoted_ids=None + ): """Register a mock response for GET on the CS user instance endpoint""" - assert httpretty.is_enabled(), 'httpretty must be enabled to mock calls.' + assert httpretty.is_enabled(), "httpretty must be enabled to mock calls." httpretty.register_uri( httpretty.GET, f"http://localhost:4567/api/v1/users/{user.id}", - body=json.dumps({ - "id": str(user.id), - "subscribed_thread_ids": subscribed_thread_ids or [], - "upvoted_ids": upvoted_ids or [], - }), - status=200 + body=json.dumps( + { + "id": str(user.id), + "subscribed_thread_ids": subscribed_thread_ids or [], + "upvoted_ids": upvoted_ids or [], + } + ), + status=200, ) def register_get_user_retire_response(self, user, status=200, body=""): """Register a mock response for GET on the CS user retirement endpoint""" - assert httpretty.is_enabled(), 'httpretty must be enabled to mock calls.' + assert httpretty.is_enabled(), "httpretty must be enabled to mock calls." httpretty.register_uri( httpretty.POST, f"http://localhost:4567/api/v1/users/{user.id}/retire", body=body, - status=status + status=status, ) def register_get_username_replacement_response(self, user, status=200, body=""): - assert httpretty.is_enabled(), 'httpretty must be enabled to mock calls.' + assert httpretty.is_enabled(), "httpretty must be enabled to mock calls." httpretty.register_uri( httpretty.POST, f"http://localhost:4567/api/v1/users/{user.id}/replace_username", body=body, - status=status + status=status, ) def register_subscribed_threads_response(self, user, threads, page, num_pages): """Register a mock response for GET on the CS user instance endpoint""" - assert httpretty.is_enabled(), 'httpretty must be enabled to mock calls.' + assert httpretty.is_enabled(), "httpretty must be enabled to mock calls." httpretty.register_uri( httpretty.GET, f"http://localhost:4567/api/v1/users/{user.id}/subscribed_threads", - body=json.dumps({ - "collection": threads, - "page": page, - "num_pages": num_pages, - "thread_count": len(threads), - }), - status=200 + body=json.dumps( + { + "collection": threads, + "page": page, + "num_pages": num_pages, + "thread_count": len(threads), + } + ), + status=200, ) def register_course_stats_response(self, course_key, stats, page, num_pages): """Register a mock response for GET on the CS user course stats instance endpoint""" - assert httpretty.is_enabled(), 'httpretty must be enabled to mock calls.' + assert httpretty.is_enabled(), "httpretty must be enabled to mock calls." httpretty.register_uri( httpretty.GET, f"http://localhost:4567/api/v1/users/{course_key}/stats", - body=json.dumps({ - "user_stats": stats, - "page": page, - "num_pages": num_pages, - "count": len(stats), - }), - status=200 + body=json.dumps( + { + "user_stats": stats, + "page": page, + "num_pages": num_pages, + "count": len(stats), + } + ), + status=200, ) def register_subscription_response(self, user): @@ -389,13 +417,13 @@ def register_subscription_response(self, user): Register a mock response for POST and DELETE on the CS user subscription endpoint """ - assert httpretty.is_enabled(), 'httpretty must be enabled to mock calls.' + assert httpretty.is_enabled(), "httpretty must be enabled to mock calls." for method in [httpretty.POST, httpretty.DELETE]: httpretty.register_uri( method, f"http://localhost:4567/api/v1/users/{user.id}/subscriptions", body=json.dumps({}), # body is unused - status=200 + status=200, ) def register_thread_votes_response(self, thread_id): @@ -403,13 +431,13 @@ def register_thread_votes_response(self, thread_id): Register a mock response for PUT and DELETE on the CS thread votes endpoint """ - assert httpretty.is_enabled(), 'httpretty must be enabled to mock calls.' + assert httpretty.is_enabled(), "httpretty must be enabled to mock calls." for method in [httpretty.PUT, httpretty.DELETE]: httpretty.register_uri( method, f"http://localhost:4567/api/v1/threads/{thread_id}/votes", body=json.dumps({}), # body is unused - status=200 + status=200, ) def register_comment_votes_response(self, comment_id): @@ -417,41 +445,39 @@ def register_comment_votes_response(self, comment_id): Register a mock response for PUT and DELETE on the CS comment votes endpoint """ - assert httpretty.is_enabled(), 'httpretty must be enabled to mock calls.' + assert httpretty.is_enabled(), "httpretty must be enabled to mock calls." for method in [httpretty.PUT, httpretty.DELETE]: httpretty.register_uri( method, f"http://localhost:4567/api/v1/comments/{comment_id}/votes", body=json.dumps({}), # body is unused - status=200 + status=200, ) def register_flag_response(self, content_type, content_id): """Register a mock response for PUT on the CS flag endpoints""" - assert httpretty.is_enabled(), 'httpretty must be enabled to mock calls.' + assert httpretty.is_enabled(), "httpretty must be enabled to mock calls." for path in ["abuse_flag", "abuse_unflag"]: httpretty.register_uri( "PUT", "http://localhost:4567/api/v1/{content_type}s/{content_id}/{path}".format( - content_type=content_type, - content_id=content_id, - path=path + content_type=content_type, content_id=content_id, path=path ), body=json.dumps({}), # body is unused - status=200 + status=200, ) def register_read_response(self, user, content_type, content_id): """ Register a mock response for POST on the CS 'read' endpoint """ - assert httpretty.is_enabled(), 'httpretty must be enabled to mock calls.' + assert httpretty.is_enabled(), "httpretty must be enabled to mock calls." httpretty.register_uri( httpretty.POST, f"http://localhost:4567/api/v1/users/{user.id}/read", - params={'source_type': content_type, 'source_id': content_id}, + params={"source_type": content_type, "source_id": content_id}, body=json.dumps({}), # body is unused - status=200 + status=200, ) def register_thread_flag_response(self, thread_id): @@ -466,48 +492,48 @@ def register_delete_thread_response(self, thread_id): """ Register a mock response for DELETE on the CS thread instance endpoint """ - assert httpretty.is_enabled(), 'httpretty must be enabled to mock calls.' + assert httpretty.is_enabled(), "httpretty must be enabled to mock calls." httpretty.register_uri( httpretty.DELETE, f"http://localhost:4567/api/v1/threads/{thread_id}", body=json.dumps({}), # body is unused - status=200 + status=200, ) def register_delete_comment_response(self, comment_id): """ Register a mock response for DELETE on the CS comment instance endpoint """ - assert httpretty.is_enabled(), 'httpretty must be enabled to mock calls.' + assert httpretty.is_enabled(), "httpretty must be enabled to mock calls." httpretty.register_uri( httpretty.DELETE, f"http://localhost:4567/api/v1/comments/{comment_id}", body=json.dumps({}), # body is unused - status=200 + status=200, ) def register_user_active_threads(self, user_id, response): """ Register a mock response for GET on the CS comment active threads endpoint """ - assert httpretty.is_enabled(), 'httpretty must be enabled to mock calls.' + assert httpretty.is_enabled(), "httpretty must be enabled to mock calls." httpretty.register_uri( httpretty.GET, f"http://localhost:4567/api/v1/users/{user_id}/active_threads", body=json.dumps(response), - status=200 + status=200, ) def register_get_subscriptions(self, thread_id, response): """ Register a mock response for GET on the CS comment active threads endpoint """ - assert httpretty.is_enabled(), 'httpretty must be enabled to mock calls.' + assert httpretty.is_enabled(), "httpretty must be enabled to mock calls." httpretty.register_uri( httpretty.GET, f"http://localhost:4567/api/v1/threads/{thread_id}/subscriptions", body=json.dumps(response), - status=200 + status=200, ) def assert_query_params_equal(self, httpretty_request, expected_params): @@ -531,7 +557,7 @@ def request_patch(self, request_data): return self.client.patch( self.url, json.dumps(request_data), - content_type="application/merge-patch+json" + content_type="application/merge-patch+json", ) def expected_thread_data(self, overrides=None): @@ -543,6 +569,8 @@ def expected_thread_data(self, overrides=None): "anonymous_to_peers": False, "author": self.user.username, "author_label": None, + "is_author_banned": False, + "author_ban_scope": None, "created_at": "1970-01-01T00:00:00Z", "updated_at": "1970-01-01T00:00:00Z", "raw_body": "Test body", @@ -588,6 +616,11 @@ def expected_thread_data(self, overrides=None): "closed_by_label": None, "close_reason": None, "close_reason_code": None, + "learner_status": "new", + "is_deleted": False, + "deleted_at": None, + "deleted_by": None, + "deleted_by_label": None, } response_data.update(overrides or {}) return response_data @@ -598,137 +631,153 @@ class ForumMockUtilsMixin(MockForumApiMixin): def register_get_threads_response(self, threads, page, num_pages): """Register a mock response for GET on the CS thread list endpoint""" - self.set_mock_return_value('get_user_threads', { - "collection": threads, - "page": page, - "num_pages": num_pages, - "thread_count": len(threads), - }) + self.set_mock_return_value( + "get_user_threads", + { + "collection": threads, + "page": page, + "num_pages": num_pages, + "thread_count": len(threads), + }, + ) def register_get_course_commentable_counts_response(self, course_id, thread_counts): """Register a mock response for GET on the CS thread list endpoint""" - self.set_mock_return_value('get_commentables_stats', thread_counts) + self.set_mock_return_value("get_commentables_stats", thread_counts) def register_get_threads_search_response(self, threads, rewrite, num_pages=1): """Register a mock response for GET on the CS thread search endpoint""" - self.set_mock_return_value('search_threads', { - "collection": threads, - "page": 1, - "num_pages": num_pages, - "corrected_text": rewrite, - "thread_count": len(threads), - }) + self.set_mock_return_value( + "search_threads", + { + "collection": threads, + "page": 1, + "num_pages": num_pages, + "corrected_text": rewrite, + "thread_count": len(threads), + }, + ) def register_post_thread_response(self, thread_data): - self.set_mock_side_effect('create_thread', make_thread_callback(thread_data)) + self.set_mock_side_effect("create_thread", make_thread_callback(thread_data)) def register_put_thread_response(self, thread_data): - self.set_mock_side_effect('update_thread', make_thread_callback(thread_data)) + self.set_mock_side_effect("update_thread", make_thread_callback(thread_data)) def register_get_thread_error_response(self, thread_id, status_code): self.set_mock_side_effect( - 'get_thread', - CommentClientRequestError(f"Thread does not exist with Id: {thread_id}") + "get_thread", + CommentClientRequestError(f"Thread does not exist with Id: {thread_id}"), ) def register_get_thread_response(self, thread): - self.set_mock_return_value('get_thread', thread) + self.set_mock_return_value("get_thread", thread) def register_get_comments_response(self, comments, page, num_pages): - self.set_mock_return_value('get_parent_comment', { - "collection": comments, - "page": page, - "num_pages": num_pages, - "comment_count": len(comments), - }) + self.set_mock_return_value( + "get_parent_comment", + { + "collection": comments, + "page": page, + "num_pages": num_pages, + "comment_count": len(comments), + }, + ) def register_post_comment_response(self, comment_data, thread_id, parent_id=None): self.set_mock_side_effect( - 'create_child_comment' if parent_id else 'create_parent_comment', - make_comment_callback(comment_data, thread_id, parent_id) + "create_child_comment" if parent_id else "create_parent_comment", + make_comment_callback(comment_data, thread_id, parent_id), ) def register_put_comment_response(self, comment_data): thread_id = comment_data["thread_id"] parent_id = comment_data.get("parent_id") self.set_mock_side_effect( - 'update_comment', - make_comment_callback(comment_data, thread_id, parent_id) + "update_comment", make_comment_callback(comment_data, thread_id, parent_id) ) def register_get_comment_error_response(self, comment_id, status_code): self.set_mock_side_effect( - 'get_parent_comment', - CommentClientRequestError(f"Comment does not exist with Id: {comment_id}") + "get_parent_comment", + CommentClientRequestError(f"Comment does not exist with Id: {comment_id}"), ) def register_get_comment_response(self, response_overrides): comment = make_minimal_cs_comment(response_overrides) - self.set_mock_return_value('get_parent_comment', comment) + self.set_mock_return_value("get_parent_comment", comment) - def register_get_user_response(self, user, subscribed_thread_ids=None, upvoted_ids=None): + def register_get_user_response( + self, user, subscribed_thread_ids=None, upvoted_ids=None + ): """Register a mock response for GET on the CS user endpoint""" self.users_map[str(user.id)] = { "id": str(user.id), "subscribed_thread_ids": subscribed_thread_ids or [], "upvoted_ids": upvoted_ids or [], } - self.set_mock_side_effect('get_user', make_user_callbacks(self.users_map)) + self.set_mock_side_effect("get_user", make_user_callbacks(self.users_map)) def register_get_user_retire_response(self, user, body=""): - self.set_mock_return_value('retire_user', body) + self.set_mock_return_value("retire_user", body) def register_get_username_replacement_response(self, user, status=200, body=""): - self.set_mock_return_value('update_username', body) + self.set_mock_return_value("update_username", body) def register_subscribed_threads_response(self, user, threads, page, num_pages): - self.set_mock_return_value('get_user_subscriptions', { - "collection": threads, - "page": page, - "num_pages": num_pages, - "thread_count": len(threads), - }) + self.set_mock_return_value( + "get_user_subscriptions", + { + "collection": threads, + "page": page, + "num_pages": num_pages, + "thread_count": len(threads), + }, + ) def register_course_stats_response(self, course_key, stats, page, num_pages): - self.set_mock_return_value('get_user_course_stats', { - "user_stats": stats, - "page": page, - "num_pages": num_pages, - "count": len(stats), - }) + self.set_mock_return_value( + "get_user_course_stats", + { + "user_stats": stats, + "page": page, + "num_pages": num_pages, + "count": len(stats), + }, + ) def register_subscription_response(self, user): - self.set_mock_return_value('create_subscription', {}) - self.set_mock_return_value('delete_subscription', {}) + self.set_mock_return_value("create_subscription", {}) + self.set_mock_return_value("delete_subscription", {}) def register_thread_votes_response(self, thread_id): - self.set_mock_return_value('update_thread_votes', {}) - self.set_mock_return_value('delete_thread_vote', {}) + self.set_mock_return_value("update_thread_votes", {}) + self.set_mock_return_value("delete_thread_vote", {}) def register_comment_votes_response(self, comment_id): - self.set_mock_return_value('update_comment_votes', {}) - self.set_mock_return_value('delete_comment_vote', {}) + self.set_mock_return_value("update_comment_votes", {}) + self.set_mock_return_value("delete_comment_vote", {}) def register_flag_response(self, content_type, content_id): - if content_type == 'thread': - self.set_mock_return_value('update_thread_flag', {}) - elif content_type == 'comment': - self.set_mock_return_value('update_comment_flag', {}) + if content_type == "thread": + self.set_mock_return_value("update_thread_flag", {}) + elif content_type == "comment": + self.set_mock_return_value("update_comment_flag", {}) def register_read_response(self, user, content_type, content_id): - self.set_mock_return_value('mark_thread_as_read', {}) + self.set_mock_return_value("mark_thread_as_read", {}) def register_delete_thread_response(self, thread_id): - self.set_mock_return_value('delete_thread', {}) + self.set_mock_return_value("delete_thread", {}) def register_delete_comment_response(self, comment_id): - self.set_mock_return_value('delete_comment', {}) + self.set_mock_return_value("delete_comment", {}) def register_user_active_threads(self, user_id, response): - self.set_mock_return_value('get_user_active_threads', response) + self.set_mock_return_value("get_user_active_threads", response) def register_get_subscriptions(self, thread_id, response): - self.set_mock_return_value('get_thread_subscriptions', response) + self.set_mock_return_value("get_thread_subscriptions", response) def register_thread_flag_response(self, thread_id): """Register a mock response for PUT on the CS thread flag endpoints""" @@ -759,7 +808,7 @@ def request_patch(self, request_data): return self.client.patch( self.url, json.dumps(request_data), - content_type="application/merge-patch+json" + content_type="application/merge-patch+json", ) def expected_thread_data(self, overrides=None): @@ -771,6 +820,8 @@ def expected_thread_data(self, overrides=None): "anonymous_to_peers": False, "author": self.user.username, "author_label": None, + "is_author_banned": False, + "author_ban_scope": None, "created_at": "1970-01-01T00:00:00Z", "updated_at": "1970-01-01T00:00:00Z", "raw_body": "Test body", @@ -816,6 +867,11 @@ def expected_thread_data(self, overrides=None): "closed_by_label": None, "close_reason": None, "close_reason_code": None, + "learner_status": "new", + "is_deleted": False, + "deleted_at": None, + "deleted_by": None, + "deleted_by_label": None, } response_data.update(overrides or {}) return response_data @@ -888,7 +944,9 @@ def make_minimal_cs_comment(overrides=None): return ret -def make_paginated_api_response(results=None, count=0, num_pages=0, next_link=None, previous_link=None): +def make_paginated_api_response( + results=None, count=0, num_pages=0, next_link=None, previous_link=None +): """ Generates the response dictionary of paginated APIs with passed data """ @@ -899,7 +957,7 @@ def make_paginated_api_response(results=None, count=0, num_pages=0, next_link=No "count": count, "num_pages": num_pages, }, - "results": results or [] + "results": results or [], } @@ -917,7 +975,9 @@ def create_profile_image(self, user, storage): with make_image_file() as image_file: create_profile_images(image_file, get_profile_image_names(user.username)) self.check_images(user, storage) - set_has_profile_image(user.username, True, self.TEST_PROFILE_IMAGE_UPLOADED_AT) + set_has_profile_image( + user.username, True, self.TEST_PROFILE_IMAGE_UPLOADED_AT + ) def check_images(self, user, storage, exist=True): """ @@ -931,7 +991,7 @@ def check_images(self, user, storage, exist=True): assert storage.exists(name) with closing(Image.open(storage.path(name))) as img: assert img.size == (size, size) - assert img.format == 'JPEG' + assert img.format == "JPEG" else: assert not storage.exists(name) @@ -939,18 +999,18 @@ def get_expected_user_profile(self, username): """ Returns the expected user profile data for a given username """ - url = 'http://example-storage.com/profile-images/{filename}_{{size}}.jpg?v={timestamp}'.format( - filename=hashlib.md5(b'secret' + username.encode('utf-8')).hexdigest(), - timestamp=self.TEST_PROFILE_IMAGE_UPLOADED_AT.strftime("%s") + url = "http://example-storage.com/profile-images/{filename}_{{size}}.jpg?v={timestamp}".format( + filename=hashlib.md5(b"secret" + username.encode("utf-8")).hexdigest(), + timestamp=self.TEST_PROFILE_IMAGE_UPLOADED_AT.strftime("%s"), ) return { - 'profile': { - 'image': { - 'has_image': True, - 'image_url_full': url.format(size=500), - 'image_url_large': url.format(size=120), - 'image_url_medium': url.format(size=50), - 'image_url_small': url.format(size=30), + "profile": { + "image": { + "has_image": True, + "image_url_full": url.format(size=500), + "image_url_large": url.format(size=120), + "image_url_medium": url.format(size=50), + "image_url_small": url.format(size=30), } } } @@ -960,14 +1020,14 @@ def parsed_body(request): """Returns a parsed dictionary version of a request body""" # This could just be HTTPrettyRequest.parsed_body, but that method double-decodes '%2B' -> '+' -> ' '. # You can just remove this method when this issue is fixed: https://github.com/gabrielfalcao/HTTPretty/issues/240 - return parse_qs(request.body.decode('utf8')) + return parse_qs(request.body.decode("utf8")) def querystring(request): """Returns a parsed dictionary version of a query string""" # This could just be HTTPrettyRequest.querystring, but that method double-decodes '%2B' -> '+' -> ' '. # You can just remove this method when this issue is fixed: https://github.com/gabrielfalcao/HTTPretty/issues/240 - return parse_qs(request.path.split('?', 1)[-1]) + return parse_qs(request.path.split("?", 1)[-1]) class ThreadMock(object): @@ -975,7 +1035,9 @@ class ThreadMock(object): A mock thread object """ - def __init__(self, thread_id, creator, title, parent_id=None, body='', commentable_id=None): + def __init__( + self, thread_id, creator, title, parent_id=None, body="", commentable_id=None + ): self.id = thread_id self.user_id = str(creator.id) self.username = creator.username diff --git a/lms/djangoapps/discussion/rest_api/urls.py b/lms/djangoapps/discussion/rest_api/urls.py index f102dc41f249..7a7cbc4b15af 100644 --- a/lms/djangoapps/discussion/rest_api/urls.py +++ b/lms/djangoapps/discussion/rest_api/urls.py @@ -9,6 +9,7 @@ from lms.djangoapps.discussion.rest_api.views import ( BulkDeleteUserPosts, + BulkRestoreUserPosts, CommentViewSet, CourseActivityStatsView, CourseDiscussionRolesAPIView, @@ -18,8 +19,11 @@ CourseTopicsViewV3, CourseView, CourseViewV2, + DeletedContentView, + DiscussionModerationViewSet, LearnerThreadView, ReplaceUsernamesView, + RestoreContent, RetireUserView, ThreadViewSet, UploadFileView, @@ -30,27 +34,49 @@ ROUTER.register("comments", CommentViewSet, basename="comment") urlpatterns = [ + # Moderation endpoints (defined first to avoid router conflicts) + path( + 'v1/moderation/ban-user/', + DiscussionModerationViewSet.as_view({'post': 'ban_user'}), + name='discussion-moderation-ban-user' + ), + path( + 'v1/moderation/unban-user/', + DiscussionModerationViewSet.as_view({'post': 'unban_user'}), + name='discussion-moderation-unban-user' + ), + path( + 'v1/moderation//unban/', + DiscussionModerationViewSet.as_view({'post': 'unban_user_by_id'}), + name='discussion-moderation-unban-by-id' + ), + path( + 'v1/moderation/bulk-delete-ban/', + DiscussionModerationViewSet.as_view({'post': 'bulk_delete_ban'}), + name='discussion-moderation-bulk-delete-ban' + ), re_path( - r"^v1/courses/{}/settings$".format( - settings.COURSE_ID_PATTERN - ), + fr'^v1/moderation/banned-users/{settings.COURSE_ID_PATTERN}/?$', + DiscussionModerationViewSet.as_view({'get': 'banned_users'}), + name='discussion-moderation-banned-users' + ), + re_path( + r"^v1/courses/{}/settings$".format(settings.COURSE_ID_PATTERN), CourseDiscussionSettingsAPIView.as_view(), name="discussion_course_settings", ), re_path( - r"^v1/courses/{}/learner/$".format( - settings.COURSE_ID_PATTERN - ), + r"^v1/courses/{}/learner/$".format(settings.COURSE_ID_PATTERN), LearnerThreadView.as_view(), name="discussion_learner_threads", ), re_path( - fr"^v1/courses/{settings.COURSE_KEY_PATTERN}/activity_stats", + rf"^v1/courses/{settings.COURSE_KEY_PATTERN}/activity_stats", CourseActivityStatsView.as_view(), name="discussion_course_activity_stats", ), re_path( - fr"^v1/courses/{settings.COURSE_ID_PATTERN}/upload$", + rf"^v1/courses/{settings.COURSE_ID_PATTERN}/upload$", UploadFileView.as_view(), name="upload_file", ), @@ -62,36 +88,55 @@ name="discussion_course_roles", ), re_path( - fr"^v1/courses/{settings.COURSE_ID_PATTERN}", + rf"^v1/courses/{settings.COURSE_ID_PATTERN}", CourseView.as_view(), - name="discussion_course" + name="discussion_course", ), re_path( - fr"^v2/courses/{settings.COURSE_ID_PATTERN}", + rf"^v2/courses/{settings.COURSE_ID_PATTERN}", CourseViewV2.as_view(), - name="discussion_course_v2" + name="discussion_course_v2", + ), + re_path( + r"^v1/accounts/retire_forum/?$", + RetireUserView.as_view(), + name="retire_discussion_user", + ), + path( + "v1/accounts/replace_username", + ReplaceUsernamesView.as_view(), + name="replace_discussion_username", ), - re_path(r'^v1/accounts/retire_forum/?$', RetireUserView.as_view(), name="retire_discussion_user"), - path('v1/accounts/replace_username', ReplaceUsernamesView.as_view(), name="replace_discussion_username"), re_path( - fr"^v1/course_topics/{settings.COURSE_ID_PATTERN}", + rf"^v1/course_topics/{settings.COURSE_ID_PATTERN}", CourseTopicsView.as_view(), - name="course_topics" + name="course_topics", ), re_path( - fr"^v2/course_topics/{settings.COURSE_ID_PATTERN}", + rf"^v2/course_topics/{settings.COURSE_ID_PATTERN}", CourseTopicsViewV2.as_view(), - name="course_topics_v2" + name="course_topics_v2", ), re_path( - fr"^v3/course_topics/{settings.COURSE_ID_PATTERN}", + rf"^v3/course_topics/{settings.COURSE_ID_PATTERN}", CourseTopicsViewV3.as_view(), - name="course_topics_v3" + name="course_topics_v3", ), re_path( - fr"^v1/bulk_delete_user_posts/{settings.COURSE_ID_PATTERN}", + rf"^v1/bulk_delete_user_posts/{settings.COURSE_ID_PATTERN}", BulkDeleteUserPosts.as_view(), - name="bulk_delete_user_posts" + name="bulk_delete_user_posts", + ), + re_path( + rf"^v1/bulk_restore_user_posts/{settings.COURSE_ID_PATTERN}", + BulkRestoreUserPosts.as_view(), + name="bulk_restore_user_posts", + ), + path("v1/restore_content", RestoreContent.as_view(), name="restore_content"), + re_path( + rf"^v1/deleted_content/{settings.COURSE_ID_PATTERN}", + DeletedContentView.as_view(), + name="deleted_content", ), - path('v1/', include(ROUTER.urls)), + path("v1/", include(ROUTER.urls)), ] diff --git a/lms/djangoapps/discussion/rest_api/utils.py b/lms/djangoapps/discussion/rest_api/utils.py index 0f02a0dcdcf2..8914527f1b6a 100644 --- a/lms/djangoapps/discussion/rest_api/utils.py +++ b/lms/djangoapps/discussion/rest_api/utils.py @@ -15,6 +15,7 @@ from common.djangoapps.student.roles import CourseInstructorRole, CourseStaffRole from common.djangoapps.student.models import CourseAccessRole +from completion.models import BlockCompletion from openedx.core.djangoapps.django_comment_common.comment_client.thread import Thread from lms.djangoapps.discussion.config.settings import ENABLE_CAPTCHA_IN_DISCUSSION @@ -496,3 +497,86 @@ def get_captcha_site_key_by_platform(platform: str) -> str | None: Get reCAPTCHA site key based on the platform. """ return settings.RECAPTCHA_SITE_KEYS.get(platform, None) + + +def _is_privileged_user(user, course_id): + """ + Check if a user has privileged roles (staff, moderator, TA, etc.) in the course. + + This helper function checks both forum roles and course access roles to determine + if a user should be considered privileged. + + Args: + user: User object to check + course_id: Course key to check roles in + + Returns: + bool: True if user has any privileged role, False otherwise + """ + # Check forum-specific privileged roles + user_roles = get_user_role_names(user, course_id) + privileged_roles = { + FORUM_ROLE_ADMINISTRATOR, + FORUM_ROLE_MODERATOR, + FORUM_ROLE_COMMUNITY_TA, + FORUM_ROLE_GROUP_MODERATOR + } + + if any(role in privileged_roles for role in user_roles): + return True + + # Check for staff roles using CourseAccessRole + # Include limited_staff for consistency with is_only_student check + return CourseAccessRole.objects.filter( + user=user, + course_id=course_id, + role__in=['instructor', 'staff', 'limited_staff'] + ).exists() + + +def _check_user_engagement(user, course_id): + """ + Returns True if the user shows meaningful engagement: + - Completed ≥ 2 blocks, or + - Completed at least 1 video or 1 problem. + """ + try: + completed = BlockCompletion.objects.filter( + user=user, context_key=course_id, completion=1.0 + ) + return ( + completed.count() >= 2 + or completed.filter(block_type__in=["video", "problem"]).exists() + ) + except (AttributeError, TypeError, ValueError): + return False + + +def get_user_learner_status(user, course_id): + """ + Determine a user's learner status in the given course. + + Possible return values: + - "anonymous" → User not logged in + - "staff" → Staff/moderator/TA + - "new" → Enrolled but no engagement + - "regular" → Enrolled and has engaged with course content + + Args: + user (User): Django user object + course_id (CourseKey): Course key to check engagement in + + Returns: + str: One of ["anonymous", "staff", "new", "regular"] + """ + # Anonymous user + if not user or not user.is_authenticated: + return "anonymous" + + # Privileged user (staff/moderator/TA) + if _is_privileged_user(user, course_id): + return "staff" + + # Engagement-based learner type + has_engagement = _check_user_engagement(user, course_id) + return "regular" if has_engagement else "new" diff --git a/lms/djangoapps/discussion/rest_api/views.py b/lms/djangoapps/discussion/rest_api/views.py index ba9818124e08..1f21c81e6ead 100644 --- a/lms/djangoapps/discussion/rest_api/views.py +++ b/lms/djangoapps/discussion/rest_api/views.py @@ -1,17 +1,20 @@ """ Discussion API views """ + import logging import uuid import edx_api_doc_tools as apidocs - +from opaque_keys import InvalidKeyError from django.contrib.auth import get_user_model from django.core.exceptions import BadRequest, ValidationError from django.shortcuts import get_object_or_404 from drf_yasg import openapi from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication -from edx_rest_framework_extensions.auth.session.authentication import SessionAuthenticationAllowInactiveUser +from edx_rest_framework_extensions.auth.session.authentication import ( + SessionAuthenticationAllowInactiveUser, +) from opaque_keys.edx.keys import CourseKey from rest_framework import permissions, status from rest_framework.authentication import SessionAuthentication @@ -21,31 +24,49 @@ from rest_framework.views import APIView from rest_framework.viewsets import ViewSet -from xmodule.modulestore.django import modulestore - from common.djangoapps.student.models import CourseEnrollment from common.djangoapps.util.file import store_uploaded_file from lms.djangoapps.course_api.blocks.api import get_blocks from lms.djangoapps.course_goals.models import UserActivity -from lms.djangoapps.discussion.rate_limit import is_content_creation_rate_limited -from lms.djangoapps.discussion.rest_api.permissions import IsAllowedToBulkDelete -from lms.djangoapps.discussion.rest_api.tasks import delete_course_post_for_user -from lms.djangoapps.discussion.toggles import ONLY_VERIFIED_USERS_CAN_POST from lms.djangoapps.discussion.django_comment_client import settings as cc_settings -from lms.djangoapps.discussion.django_comment_client.utils import get_group_id_for_comments_service +from lms.djangoapps.discussion.django_comment_client.utils import ( + get_group_id_for_comments_service, +) +from lms.djangoapps.discussion.rate_limit import is_content_creation_rate_limited +from lms.djangoapps.discussion.rest_api.permissions import IsAllowedToBulkDelete, IsAllowedToRestore +from lms.djangoapps.discussion.rest_api.tasks import ( + delete_course_post_for_user, + restore_course_post_for_user, +) +from lms.djangoapps.discussion.toggles import ONLY_VERIFIED_USERS_CAN_POST, ENABLE_DISCUSSION_BAN from lms.djangoapps.instructor.access import update_forum_role -from openedx.core.djangoapps.discussions.config.waffle import ENABLE_NEW_STRUCTURE_DISCUSSIONS -from openedx.core.djangoapps.discussions.models import DiscussionsConfiguration, Provider +from openedx.core.djangoapps.discussions.config.waffle import ( + ENABLE_NEW_STRUCTURE_DISCUSSIONS, +) +from openedx.core.djangoapps.discussions.models import ( + DiscussionsConfiguration, + Provider, +) from openedx.core.djangoapps.discussions.serializers import DiscussionSettingsSerializer from openedx.core.djangoapps.django_comment_common import comment_client -from openedx.core.djangoapps.django_comment_common.models import CourseDiscussionSettings, Role from openedx.core.djangoapps.django_comment_common.comment_client.comment import Comment from openedx.core.djangoapps.django_comment_common.comment_client.thread import Thread -from openedx.core.djangoapps.user_api.accounts.permissions import CanReplaceUsername, CanRetireUser +from openedx.core.djangoapps.django_comment_common.models import ( + CourseDiscussionSettings, + Role, +) +from openedx.core.djangoapps.user_api.accounts.permissions import ( + CanReplaceUsername, + CanRetireUser, +) from openedx.core.djangoapps.user_api.models import UserRetirementStatus -from openedx.core.lib.api.authentication import BearerAuthentication, BearerAuthenticationAllowInactiveUser +from openedx.core.lib.api.authentication import ( + BearerAuthentication, + BearerAuthenticationAllowInactiveUser, +) from openedx.core.lib.api.parsers import MergePatchParser from openedx.core.lib.api.view_utils import DeveloperErrorViewMixin, view_auth_classes +from xmodule.modulestore.django import modulestore from ..rest_api.api import ( create_comment, @@ -57,10 +78,10 @@ get_course_discussion_user_stats, get_course_topics, get_course_topics_v2, + get_learner_active_thread_list, get_response_comments, get_thread, get_thread_list, - get_learner_active_thread_list, get_user_comments, get_v2_course_topics_as_v1, update_comment, @@ -88,10 +109,10 @@ from .utils import ( create_blocks_params, create_topics_v3_structure, - is_captcha_enabled, - verify_recaptcha_token, get_course_id_from_thread_id, + is_captcha_enabled, is_only_student, + verify_recaptcha_token, ) log = logging.getLogger(__name__) @@ -107,14 +128,16 @@ class CourseView(DeveloperErrorViewMixin, APIView): @apidocs.schema( parameters=[ - apidocs.string_parameter("course_id", apidocs.ParameterLocation.PATH, description="Course ID") + apidocs.string_parameter( + "course_id", apidocs.ParameterLocation.PATH, description="Course ID" + ) ], responses={ 200: CourseMetadataSerailizer(read_only=True, required=False), 401: "The requester is not authenticated.", 403: "The requester cannot access the specified course.", 404: "The requested course does not exist.", - } + }, ) def get(self, request, course_id): """ @@ -126,7 +149,9 @@ def get(self, request, course_id): """ course_key = CourseKey.from_string(course_id) # TODO: which class is right? # Record user activity for tracking progress towards a user's course goals (for mobile app) - UserActivity.record_user_activity(request.user, course_key, request=request, only_if_mobile_app=True) + UserActivity.record_user_activity( + request.user, course_key, request=request, only_if_mobile_app=True + ) return Response(get_course(request, course_key)) @@ -138,14 +163,16 @@ class CourseViewV2(DeveloperErrorViewMixin, APIView): @apidocs.schema( parameters=[ - apidocs.string_parameter("course_id", apidocs.ParameterLocation.PATH, description="Course ID") + apidocs.string_parameter( + "course_id", apidocs.ParameterLocation.PATH, description="Course ID" + ) ], responses={ 200: CourseMetadataSerailizer(read_only=True, required=False), 401: "The requester is not authenticated.", 403: "The requester cannot access the specified course.", 404: "The requested course does not exist.", - } + }, ) def get(self, request, course_id): """ @@ -156,7 +183,9 @@ def get(self, request, course_id): """ course_key = CourseKey.from_string(course_id) # Record user activity for tracking progress towards a user's course goals (for mobile app) - UserActivity.record_user_activity(request.user, course_key, request=request, only_if_mobile_app=True) + UserActivity.record_user_activity( + request.user, course_key, request=request, only_if_mobile_app=True + ) return Response(get_course(request, course_key, False)) @@ -221,14 +250,14 @@ def get(self, request, course_key_string): form_query_string = CourseActivityStatsForm(request.query_params) if not form_query_string.is_valid(): raise ValidationError(form_query_string.errors) - order_by = form_query_string.cleaned_data.get('order_by', None) + order_by = form_query_string.cleaned_data.get("order_by", None) order_by = UserOrdering(order_by) if order_by else None - username_search_string = form_query_string.cleaned_data.get('username', None) + username_search_string = form_query_string.cleaned_data.get("username", None) data = get_course_discussion_user_stats( request, course_key_string, - form_query_string.cleaned_data['page'], - form_query_string.cleaned_data['page_size'], + form_query_string.cleaned_data["page"], + form_query_string.cleaned_data["page_size"], order_by, username_search_string, ) @@ -268,19 +297,17 @@ def get(self, request, course_id): Implements the GET method as described in the class docstring. """ course_key = CourseKey.from_string(course_id) - topic_ids = self.request.GET.get('topic_id') - topic_ids = set(topic_ids.strip(',').split(',')) if topic_ids else None + topic_ids = self.request.GET.get("topic_id") + topic_ids = set(topic_ids.strip(",").split(",")) if topic_ids else None with modulestore().bulk_operations(course_key): configuration = DiscussionsConfiguration.get(context_key=course_key) provider = configuration.provider_type # This will be removed when mobile app will support new topic structure - new_structure_enabled = ENABLE_NEW_STRUCTURE_DISCUSSIONS.is_enabled(course_key) + new_structure_enabled = ENABLE_NEW_STRUCTURE_DISCUSSIONS.is_enabled( + course_key + ) if provider == Provider.OPEN_EDX and new_structure_enabled: - response = get_v2_course_topics_as_v1( - request, - course_key, - topic_ids - ) + response = get_v2_course_topics_as_v1(request, course_key, topic_ids) else: response = get_course_topics( request, @@ -288,7 +315,9 @@ def get(self, request, course_id): topic_ids, ) # Record user activity for tracking progress towards a user's course goals (for mobile app) - UserActivity.record_user_activity(request.user, course_key, request=request, only_if_mobile_app=True) + UserActivity.record_user_activity( + request.user, course_key, request=request, only_if_mobile_app=True + ) return Response(response) @@ -304,17 +333,17 @@ class CourseTopicsViewV2(DeveloperErrorViewMixin, APIView): @apidocs.schema( parameters=[ apidocs.string_parameter( - 'course_id', + "course_id", apidocs.ParameterLocation.PATH, description="Course ID", ), apidocs.string_parameter( - 'topic_id', + "topic_id", apidocs.ParameterLocation.QUERY, description="Comma-separated list of topic ids to filter", ), openapi.Parameter( - 'order_by', + "order_by", apidocs.ParameterLocation.QUERY, required=False, type=openapi.TYPE_STRING, @@ -327,7 +356,7 @@ class CourseTopicsViewV2(DeveloperErrorViewMixin, APIView): 401: "The requester is not authenticated.", 403: "The requester cannot access the specified course.", 404: "The requested course does not exist.", - } + }, ) def get(self, request, course_id): """ @@ -348,7 +377,7 @@ def get(self, request, course_id): course_key, request.user, form_query_params.cleaned_data["topic_id"], - form_query_params.cleaned_data["order_by"] + form_query_params.cleaned_data["order_by"], ) return Response(response) @@ -416,17 +445,17 @@ def get(self, request, course_id): blocks_params = create_blocks_params(course_usage_key, request.user) blocks = get_blocks( request, - blocks_params['usage_key'], - blocks_params['user'], - blocks_params['depth'], - blocks_params['nav_depth'], - blocks_params['requested_fields'], - blocks_params['block_counts'], - blocks_params['student_view_data'], - blocks_params['return_type'], - blocks_params['block_types_filter'], + blocks_params["usage_key"], + blocks_params["user"], + blocks_params["depth"], + blocks_params["nav_depth"], + blocks_params["requested_fields"], + blocks_params["block_counts"], + blocks_params["student_view_data"], + blocks_params["return_type"], + blocks_params["block_types_filter"], hide_access_denials=False, - )['blocks'] + )["blocks"] topics = create_topics_v3_structure(blocks, topics) return Response(topics) @@ -627,8 +656,12 @@ class ThreadViewSet(DeveloperErrorViewMixin, ViewSet): No content is returned for a DELETE request """ + lookup_field = "thread_id" - parser_classes = (JSONParser, MergePatchParser,) + parser_classes = ( + JSONParser, + MergePatchParser, + ) def list(self, request): """ @@ -641,7 +674,10 @@ class docstring. # Record user activity for tracking progress towards a user's course goals (for mobile app) UserActivity.record_user_activity( - request.user, form.cleaned_data["course_id"], request=request, only_if_mobile_app=True + request.user, + form.cleaned_data["course_id"], + request=request, + only_if_mobile_app=True, ) return get_thread_list( @@ -660,14 +696,15 @@ class docstring. form.cleaned_data["order_direction"], form.cleaned_data["requested_fields"], form.cleaned_data["count_flagged"], + form.cleaned_data["show_deleted"], ) def retrieve(self, request, thread_id=None): """ Implements the GET method for thread ID """ - requested_fields = request.GET.get('requested_fields') - course_id = request.GET.get('course_id') + requested_fields = request.GET.get("requested_fields") + course_id = request.GET.get("course_id") return Response(get_thread(request, thread_id, requested_fields, course_id)) def create(self, request): @@ -681,21 +718,28 @@ class docstring. course_key = CourseKey.from_string(course_key_str) if is_content_creation_rate_limited(request, course_key=course_key): - return Response("Too many requests", status=status.HTTP_429_TOO_MANY_REQUESTS) + return Response( + "Too many requests", status=status.HTTP_429_TOO_MANY_REQUESTS + ) if is_captcha_enabled(course_key) and is_only_student(course_key, request.user): - captcha_token = request.data.get('captcha_token') + captcha_token = request.data.get("captcha_token") if not captcha_token: - raise ValidationError({'captcha_token': 'This field is required.'}) + raise ValidationError({"captcha_token": "This field is required."}) if not verify_recaptcha_token(captcha_token): - return Response({'error': 'CAPTCHA verification failed.'}, status=400) - - if ONLY_VERIFIED_USERS_CAN_POST.is_enabled(course_key) and not request.user.is_active: - raise ValidationError({"detail": "Only verified users can post in discussions."}) + return Response({"error": "CAPTCHA verification failed."}, status=400) + + if ( + ONLY_VERIFIED_USERS_CAN_POST.is_enabled(course_key) + and not request.user.is_active + ): + raise ValidationError( + {"detail": "Only verified users can post in discussions."} + ) data = request.data.copy() - data.pop('captcha_token', None) + data.pop("captcha_token", None) return Response(create_thread(request, data)) def partial_update(self, request, thread_id): @@ -762,24 +806,27 @@ def get(self, request, course_id=None): Implements the GET method as described in the class docstring. """ course_key = CourseKey.from_string(course_id) - page_num = request.GET.get('page', 1) - threads_per_page = request.GET.get('page_size', 10) - count_flagged = request.GET.get('count_flagged', False) - thread_type = request.GET.get('thread_type') - order_by = request.GET.get('order_by') + page_num = request.GET.get("page", 1) + threads_per_page = request.GET.get("page_size", 10) + count_flagged = request.GET.get("count_flagged", False) + thread_type = request.GET.get("thread_type") + order_by = request.GET.get("order_by") order_by_mapping = { "last_activity_at": "activity", "comment_count": "comments", - "vote_count": "votes" + "vote_count": "votes", } - order_by = order_by_mapping.get(order_by, 'activity') - post_status = request.GET.get('status', None) + order_by = order_by_mapping.get(order_by, "activity") + post_status = request.GET.get("status", None) + show_deleted = request.GET.get("show_deleted", "false").lower() == "true" discussion_id = None - username = request.GET.get('username', None) + username = request.GET.get("username", None) user = get_object_or_404(User, username=username) group_id = None try: - group_id = get_group_id_for_comments_service(request, course_key, discussion_id) + group_id = get_group_id_for_comments_service( + request, course_key, discussion_id + ) except ValueError: pass @@ -792,14 +839,17 @@ def get(self, request, course_id=None): "count_flagged": count_flagged, "thread_type": thread_type, "sort_key": order_by, + "show_deleted": show_deleted, } if post_status: - if post_status not in ['flagged', 'unanswered', 'unread', 'unresponded']: - raise ValidationError({ - "status": [ - f"Invalid value. '{post_status}' must be 'flagged', 'unanswered', 'unread' or 'unresponded" - ] - }) + if post_status not in ["flagged", "unanswered", "unread", "unresponded"]: + raise ValidationError( + { + "status": [ + f"Invalid value. '{post_status}' must be 'flagged', 'unanswered', 'unread' or 'unresponded" + ] + } + ) query_params[post_status] = True return get_learner_active_thread_list(request, course_key, query_params) @@ -968,8 +1018,12 @@ class CommentViewSet(DeveloperErrorViewMixin, ViewSet): No content is returned for a DELETE request """ + lookup_field = "comment_id" - parser_classes = (JSONParser, MergePatchParser,) + parser_classes = ( + JSONParser, + MergePatchParser, + ) def list(self, request): """ @@ -1010,7 +1064,8 @@ def list_by_thread(self, request): form.cleaned_data["page_size"], form.cleaned_data["flagged"], form.cleaned_data["requested_fields"], - form.cleaned_data["merge_question_type_responses"] + form.cleaned_data["merge_question_type_responses"], + form.cleaned_data["show_deleted"], ) def list_by_user(self, request): @@ -1057,21 +1112,28 @@ class docstring. course_key = CourseKey.from_string(course_key_str) if is_content_creation_rate_limited(request, course_key=course_key): - return Response("Too many requests", status=status.HTTP_429_TOO_MANY_REQUESTS) + return Response( + "Too many requests", status=status.HTTP_429_TOO_MANY_REQUESTS + ) if is_captcha_enabled(course_key) and is_only_student(course_key, request.user): - captcha_token = request.data.get('captcha_token') + captcha_token = request.data.get("captcha_token") if not captcha_token: - raise ValidationError({'captcha_token': 'This field is required.'}) + raise ValidationError({"captcha_token": "This field is required."}) if not verify_recaptcha_token(captcha_token): - return Response({'error': 'CAPTCHA verification failed.'}, status=400) - - if ONLY_VERIFIED_USERS_CAN_POST.is_enabled(course_key) and not request.user.is_active: - raise ValidationError({"detail": "Only verified users can post in discussions."}) + return Response({"error": "CAPTCHA verification failed."}, status=400) + + if ( + ONLY_VERIFIED_USERS_CAN_POST.is_enabled(course_key) + and not request.user.is_active + ): + raise ValidationError( + {"detail": "Only verified users can post in discussions."} + ) data = request.data.copy() - data.pop('captcha_token', None) + data.pop("captcha_token", None) return Response(create_comment(request, data)) def destroy(self, request, comment_id): @@ -1147,8 +1209,11 @@ def post(self, request, course_id): unique_file_name = f"{course_id}/{thread_key}/{uuid.uuid4()}" try: file_storage, stored_file_name = store_uploaded_file( - request, "uploaded_file", cc_settings.ALLOWED_UPLOAD_FILE_TYPES, - unique_file_name, max_file_size=cc_settings.MAX_UPLOAD_FILE_SIZE, + request, + "uploaded_file", + cc_settings.ALLOWED_UPLOAD_FILE_TYPES, + unique_file_name, + max_file_size=cc_settings.MAX_UPLOAD_FILE_SIZE, ) except ValueError as err: raise BadRequest("no `uploaded_file` was provided") from err @@ -1189,10 +1254,12 @@ def post(self, request): """ Implements the retirement endpoint. """ - username = request.data['username'] + username = request.data["username"] try: - retirement = UserRetirementStatus.get_retirement_for_retirement_action(username) + retirement = UserRetirementStatus.get_retirement_for_retirement_action( + username + ) cc_user = comment_client.User.from_django_user(retirement.user) # Send the retired username to the forums service, as the service cannot generate @@ -1247,7 +1314,9 @@ def post(self, request): for username_pair in username_mappings: current_username = list(username_pair.keys())[0] new_username = list(username_pair.values())[0] - successfully_replaced = self._replace_username(current_username, new_username) + successfully_replaced = self._replace_username( + current_username, new_username + ) if successfully_replaced: successful_replacements.append({current_username: new_username}) else: @@ -1257,8 +1326,8 @@ def post(self, request): status=status.HTTP_200_OK, data={ "successful_replacements": successful_replacements, - "failed_replacements": failed_replacements - } + "failed_replacements": failed_replacements, + }, ) def _replace_username(self, current_username, new_username): @@ -1304,7 +1373,7 @@ def _replace_username(self, current_username, new_username): return True def _has_valid_schema(self, post_data): - """ Verifies the data is a list of objects with a single key:value pair """ + """Verifies the data is a list of objects with a single key:value pair""" if not isinstance(post_data, list): return False for obj in post_data: @@ -1364,12 +1433,16 @@ class CourseDiscussionSettingsAPIView(DeveloperErrorViewMixin, APIView): * available_division_schemes: A list of available division schemes for the course. """ + authentication_classes = ( JwtAuthentication, BearerAuthenticationAllowInactiveUser, SessionAuthenticationAllowInactiveUser, ) - parser_classes = (JSONParser, MergePatchParser,) + parser_classes = ( + JSONParser, + MergePatchParser, + ) permission_classes = (permissions.IsAuthenticated, IsStaffOrAdmin) def _get_request_kwargs(self, course_id): @@ -1385,14 +1458,14 @@ def get(self, request, course_id): if not form.is_valid(): raise ValidationError(form.errors) - course_key = form.cleaned_data['course_key'] - course = form.cleaned_data['course'] + course_key = form.cleaned_data["course_key"] + course = form.cleaned_data["course"] discussion_settings = CourseDiscussionSettings.get(course_key) serializer = DiscussionSettingsSerializer( discussion_settings, context={ - 'course': course, - 'settings': discussion_settings, + "course": course, + "settings": discussion_settings, }, partial=True, ) @@ -1411,15 +1484,15 @@ def patch(self, request, course_id): if not form.is_valid(): raise ValidationError(form.errors) - course = form.cleaned_data['course'] - course_key = form.cleaned_data['course_key'] + course = form.cleaned_data["course"] + course_key = form.cleaned_data["course_key"] discussion_settings = CourseDiscussionSettings.get(course_key) serializer = DiscussionSettingsSerializer( discussion_settings, context={ - 'course': course, - 'settings': discussion_settings, + "course": course, + "settings": discussion_settings, }, data=request.data, partial=True, @@ -1488,6 +1561,7 @@ class CourseDiscussionRolesAPIView(DeveloperErrorViewMixin, APIView): * division_scheme: The division scheme used by the course. """ + authentication_classes = ( JwtAuthentication, BearerAuthenticationAllowInactiveUser, @@ -1508,11 +1582,13 @@ def get(self, request, course_id, rolename): if not form.is_valid(): raise ValidationError(form.errors) - course_id = form.cleaned_data['course_key'] - role = form.cleaned_data['role'] + course_id = form.cleaned_data["course_key"] + role = form.cleaned_data["role"] - data = {'course_id': course_id, 'users': role.users.all()} - context = {'course_discussion_settings': CourseDiscussionSettings.get(course_id)} + data = {"course_id": course_id, "users": role.users.all()} + context = { + "course_discussion_settings": CourseDiscussionSettings.get(course_id) + } serializer = DiscussionRolesListSerializer(data, context=context) return Response(serializer.data) @@ -1526,23 +1602,25 @@ def post(self, request, course_id, rolename): if not form.is_valid(): raise ValidationError(form.errors) - course_id = form.cleaned_data['course_key'] - rolename = form.cleaned_data['rolename'] + course_id = form.cleaned_data["course_key"] + rolename = form.cleaned_data["rolename"] serializer = DiscussionRolesSerializer(data=request.data) if not serializer.is_valid(): raise ValidationError(serializer.errors) - action = serializer.validated_data['action'] - user = serializer.validated_data['user'] + action = serializer.validated_data["action"] + user = serializer.validated_data["user"] try: update_forum_role(course_id, user, rolename, action) except Role.DoesNotExist as err: raise ValidationError(f"Role '{rolename}' does not exist") from err - role = form.cleaned_data['role'] - data = {'course_id': course_id, 'users': role.users.all()} - context = {'course_discussion_settings': CourseDiscussionSettings.get(course_id)} + role = form.cleaned_data["role"] + data = {"course_id": course_id, "users": role.users.all()} + context = { + "course_discussion_settings": CourseDiscussionSettings.get(course_id) + } serializer = DiscussionRolesListSerializer(data, context=context) return Response(serializer.data) @@ -1566,7 +1644,9 @@ class BulkDeleteUserPosts(DeveloperErrorViewMixin, APIView): """ authentication_classes = ( - JwtAuthentication, BearerAuthentication, SessionAuthentication, + JwtAuthentication, + BearerAuthentication, + SessionAuthentication, ) permission_classes = (permissions.IsAuthenticated, IsAllowedToBulkDelete) @@ -1587,23 +1667,26 @@ def post(self, request, course_id): course_ids = [course_id] if course_or_org == "org": org_id = CourseKey.from_string(course_id).org - enrollments = CourseEnrollment.objects.filter(user=request.user).values_list('course_id', flat=True) - course_ids.extend([ - str(c_id) - for c_id in enrollments - if c_id.org == org_id - ]) + enrollments = CourseEnrollment.objects.filter( + user=user + ).values_list("course_id", flat=True) + course_ids.extend([str(c_id) for c_id in enrollments if c_id.org == org_id]) course_ids = list(set(course_ids)) log.info(f"<> {username} enrolled in {enrollments}") - log.info(f"<> Posts for {username} in {course_ids} - for {course_or_org} {course_id}") + log.info( + f"<> Posts for {username} in {course_ids} - for {course_or_org} {course_id}" + ) comment_count = Comment.get_user_comment_count(user.id, course_ids) thread_count = Thread.get_user_threads_count(user.id, course_ids) - log.info(f"<> {username} in {course_ids} - Count thread {thread_count}, comment {comment_count}") + log.info( + f"<> {username} in {course_ids} - Count thread {thread_count}, comment {comment_count}" + ) if execute_task: event_data = { "triggered_by": request.user.username, + "triggered_by_user_id": str(request.user.id), "username": username, "course_or_org": course_or_org, "course_key": course_id, @@ -1613,5 +1696,1271 @@ def post(self, request, course_id): ) return Response( {"comment_count": comment_count, "thread_count": thread_count}, - status=status.HTTP_202_ACCEPTED + status=status.HTTP_202_ACCEPTED, + ) + + +class RestoreContent(DeveloperErrorViewMixin, APIView): + """ + **Use Cases** + A privileged user that can restore individual soft-deleted threads, comments, or responses. + + **Example Requests**: + POST /api/discussion/v1/restore_content + Request Body: + { + "content_type": "thread", // "thread", "comment", or "response" + "content_id": "thread_id_or_comment_id", + "course_id": "course-v1:edX+DemoX+Demo_Course" + } + + **Example Response**: + {"success": true, "message": "Content restored successfully"} + """ + + authentication_classes = ( + JwtAuthentication, + BearerAuthentication, + SessionAuthentication, + ) + permission_classes = (permissions.IsAuthenticated, IsAllowedToRestore) + + def post(self, request): + """ + Implements the restore individual content endpoint. + """ + content_type = request.data.get("content_type") + content_id = request.data.get("content_id") + course_id = request.data.get("course_id") + + if not all([content_type, content_id, course_id]): + raise BadRequest("content_type, content_id, and course_id are required.") + + if content_type not in ["thread", "comment", "response"]: + raise BadRequest("content_type must be 'thread', 'comment', or 'response'.") + + restored_by_user_id = str(request.user.id) + + try: + if content_type == "thread": + success = Thread.restore_thread( + content_id, course_id=course_id, restored_by=restored_by_user_id + ) + else: # comment or response (both are comments in the backend) + success = Comment.restore_comment( + content_id, course_id=course_id, restored_by=restored_by_user_id + ) + + if success: + return Response( + { + "success": True, + "message": f"{content_type.capitalize()} restored successfully", + }, + status=status.HTTP_200_OK, + ) + else: + return Response( + { + "success": False, + "message": f"{content_type.capitalize()} not found or already restored", + }, + status=status.HTTP_404_NOT_FOUND, + ) + except Exception as e: # pylint: disable=broad-exception-caught + log.error("Error restoring %s %s: %s", content_type, content_id, str(e)) + return Response( + { + "success": False, + "message": f"Error restoring {content_type}: {str(e)}", + }, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + +class BulkRestoreUserPosts(DeveloperErrorViewMixin, APIView): + """ + **Use Cases** + A privileged user that can restore all soft-deleted posts and comments made by a user. + It returns expected number of comments and threads that will be restored + + **Example Requests**: + POST /api/discussion/v1/bulk_restore_user_posts/{course_id} + Query Parameters: + username: The username of the user whose posts are to be restored + course_id: Course id for which posts are to be restored + execute: If True, runs restoration task + course_or_org: If 'course', restores posts in the course, if 'org', restores posts in all courses of the org + + **Example Response**: + {"comment_count": 5, "thread_count": 3} + """ + + authentication_classes = ( + JwtAuthentication, + BearerAuthentication, + SessionAuthentication, + ) + permission_classes = (permissions.IsAuthenticated, IsAllowedToBulkDelete) + + def post(self, request, course_id): + """ + Implements the restore user posts endpoint. + """ + username = request.GET.get("username", None) + execute_task = request.GET.get("execute", "false").lower() == "true" + if (not username) or (not course_id): + raise BadRequest("username and course_id are required.") + course_or_org = request.GET.get("course_or_org", "course") + if course_or_org not in ["course", "org"]: + raise BadRequest("course_or_org must be either 'course' or 'org'.") + + user = get_object_or_404(User, username=username) + course_ids = [course_id] + if course_or_org == "org": + org_id = CourseKey.from_string(course_id).org + enrollments = CourseEnrollment.objects.filter( + user=user + ).values_list("course_id", flat=True) + course_ids.extend([str(c_id) for c_id in enrollments if c_id.org == org_id]) + course_ids = list(set(course_ids)) + log.info("<> %s enrolled in %s", username, enrollments) + log.info( + "<> Posts for %s in %s - for %s %s", + username, + course_ids, + course_or_org, + course_id, + ) + + comment_count = Comment.get_user_deleted_comment_count(user.id, course_ids) + thread_count = Thread.get_user_deleted_threads_count(user.id, course_ids) + log.info( + "<> %s in %s - Count thread %s, comment %s", + username, + course_ids, + thread_count, + comment_count, + ) + + if execute_task: + event_data = { + "triggered_by": request.user.username, + "triggered_by_user_id": str(request.user.id), + "username": username, + "course_or_org": course_or_org, + "course_key": course_id, + } + restore_course_post_for_user.apply_async( + args=(user.id, username, course_ids, event_data), + ) + return Response( + {"comment_count": comment_count, "thread_count": thread_count}, + status=status.HTTP_202_ACCEPTED, + ) + + +class DeletedContentView(DeveloperErrorViewMixin, APIView): + """ + **Use Cases** + Retrieve all deleted content (threads, comments, responses) for a course. + This endpoint allows privileged users to fetch deleted discussion content. + + **Example Requests**: + GET /api/discussion/v1/deleted_content/course-v1:edX+DemoX+Demo_Course + GET /api/discussion/v1/deleted_content/course-v1:edX+DemoX+Demo_Course?content_type=thread + GET /api/discussion/v1/deleted_content/course-v1:edX+DemoX+Demo_Course?page=1&per_page=20 + + **Example Response**: + { + "results": [ + { + "id": "thread_id", + "type": "thread", + "title": "Deleted Thread Title", + "body": "Thread content...", + "course_id": "course-v1:edX+DemoX+Demo_Course", + "author_id": "user_123", + "deleted_at": "2023-11-19T10:30:00Z", + "deleted_by": "moderator_456" + } + ], + "pagination": { + "page": 1, + "per_page": 20, + "total_count": 50, + "num_pages": 3 + } + } + """ + + authentication_classes = ( + JwtAuthentication, + BearerAuthentication, + SessionAuthentication, + ) + permission_classes = (permissions.IsAuthenticated, IsAllowedToBulkDelete) + + def get(self, request, course_id): + """ + Retrieve all deleted content for a course. + """ + try: + course_key = CourseKey.from_string(course_id) + except Exception as e: + raise BadRequest("Invalid course_id") from e + + # Get query parameters + content_type = request.GET.get( + "content_type", None + ) # 'thread', 'comment', or None for all + page = int(request.GET.get("page", 1)) + per_page = int(request.GET.get("per_page", 20)) + author_id = request.GET.get("author_id", None) + + # Validate parameters + if content_type and content_type not in ["thread", "comment"]: + raise BadRequest("content_type must be 'thread' or 'comment'") + + per_page = min(per_page, 100) # Limit to prevent excessive load + + try: + # Import here to avoid circular imports + from lms.djangoapps.discussion.rest_api.api import ( + get_deleted_content_for_course, + ) + + results = get_deleted_content_for_course( + request=request, + course_id=str(course_key), + content_type=content_type, + page=page, + per_page=per_page, + author_id=author_id, + ) + + return Response(results, status=status.HTTP_200_OK) + + except Exception as e: # pylint: disable=broad-exception-caught + logging.exception( + "Error retrieving deleted content for course %s: %s", course_id, e + ) + return Response( + {"error": "Failed to retrieve deleted content"}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + +class DiscussionModerationViewSet(DeveloperErrorViewMixin, ViewSet): + """ + **Use Cases** + + Perform bulk moderation actions on discussion posts and manage user bans. + + **Example Requests** + + POST /api/discussion/v1/moderation/bulk-delete-ban/ + GET /api/discussion/v1/moderation/banned-users/course-v1:edX+DemoX+Demo + POST /api/discussion/v1/moderation/123/unban/ + """ + + authentication_classes = ( + JwtAuthentication, BearerAuthentication, SessionAuthentication, + ) + permission_classes = (permissions.IsAuthenticated, IsAllowedToBulkDelete) + + def get_permissions(self): + """ + Return permission instances for the view. + + For unban_user, unban_user_by_id, and banned_users actions, we only need IsAuthenticated + because we check course-specific permissions inside the action method after retrieving the ban. + For ban_user, we check permissions inside the action based on scope. + """ + if self.action in ['unban_user', 'unban_user_by_id', 'banned_users', 'ban_user']: + return [permissions.IsAuthenticated()] + return super().get_permissions() + + @apidocs.schema( + body=openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['course_id'], + properties={ + 'user_id': openapi.Schema( + type=openapi.TYPE_INTEGER, + description='ID of the user to ban (required if username is not provided)' + ), + 'username': openapi.Schema( + type=openapi.TYPE_STRING, + description='Username of the user to ban (required if user_id is not provided)' + ), + 'course_id': openapi.Schema( + type=openapi.TYPE_STRING, + description='Course ID (e.g., course-v1:edX+DemoX+Demo_Course)' + ), + 'scope': openapi.Schema( + type=openapi.TYPE_STRING, + description='Scope of ban: "course" or "organization"', + enum=['course', 'organization'], + default='course' + ), + 'reason': openapi.Schema( + type=openapi.TYPE_STRING, + description='Reason for the ban (optional)', + max_length=1000 + ), + }, + ), + responses={ + 201: openapi.Response( + description='User banned successfully', + schema=openapi.Schema( + type=openapi.TYPE_OBJECT, + properties={ + 'status': openapi.Schema(type=openapi.TYPE_STRING, example='success'), + 'message': openapi.Schema(type=openapi.TYPE_STRING), + 'ban_id': openapi.Schema(type=openapi.TYPE_INTEGER), + 'user_id': openapi.Schema(type=openapi.TYPE_INTEGER), + 'username': openapi.Schema(type=openapi.TYPE_STRING), + 'scope': openapi.Schema(type=openapi.TYPE_STRING), + 'course_id': openapi.Schema(type=openapi.TYPE_STRING), + }, + ), + ), + 400: 'Invalid request data or user already banned.', + 401: 'The requester is not authenticated.', + 403: 'The requester does not have permission to ban users.', + 404: 'The specified user does not exist.', + }, + ) + def _validate_ban_request_and_get_user(self, request, serializer_data): + """ + Validate ban request and retrieve target user. + + Returns tuple of (user, course_key, ban_scope, reason) or Response object on error. + """ + from lms.djangoapps.discussion.rest_api.utils import ( + _is_privileged_user, ) + + user_id = serializer_data.get('user_id') + lookup_username = serializer_data.get('lookup_username') + course_id_str = serializer_data['course_id'] + ban_scope = serializer_data.get('scope', 'course') + reason = serializer_data.get('reason', '').strip() + + try: + course_key = CourseKey.from_string(course_id_str) + except InvalidKeyError: + return Response( + {'error': f'Invalid course_id: {course_id_str}'}, + status=status.HTTP_400_BAD_REQUEST, + ) + + try: + if user_id: + user = User.objects.get(id=user_id) + elif lookup_username: + user = User.objects.get(username=lookup_username) + else: + return Response( + {'error': 'Either user_id or username must be provided'}, + status=status.HTTP_400_BAD_REQUEST + ) + except User.DoesNotExist: + identifier = user_id if user_id else lookup_username + return Response( + {'error': f'User {identifier} does not exist'}, + status=status.HTTP_404_NOT_FOUND + ) + + # Check if user is staff/privileged - they shouldn't be banned + if _is_privileged_user(user, course_key): + return Response( + { + 'error': ( + f'Cannot ban staff or privileged users. User {user.username} ' + f'has elevated permissions in this course.' + ) + }, + status=status.HTTP_400_BAD_REQUEST + ) + + return user, course_key, ban_scope, reason + + def _check_ban_permissions(self, request, ban_scope, course_key): + """ + Check if user has permission to ban at the specified scope. + + Returns Response object on permission denied, None if permitted. + """ + from lms.djangoapps.discussion.rest_api.permissions import can_take_action_on_spam + from common.djangoapps.student.roles import GlobalStaff + + if ban_scope == 'course': + if not can_take_action_on_spam(request.user, course_key): + return Response( + {'error': 'You do not have permission to ban users in this course'}, + status=status.HTTP_403_FORBIDDEN + ) + else: + if not (GlobalStaff().has_user(request.user) or request.user.is_staff): + return Response( + {'error': 'Organization-level bans require global staff permissions'}, + status=status.HTTP_403_FORBIDDEN + ) + + if not ENABLE_DISCUSSION_BAN.is_enabled(course_key): + return Response( + {'error': 'Discussion ban feature is not enabled for this course'}, + status=status.HTTP_403_FORBIDDEN + ) + + return None + + def _get_or_create_ban(self, user, course_key, ban_scope, reason, request): + """ + Get existing ban or create new one. + + Returns tuple of (ban, action_type, message) or Response object on error. + """ + from forum import api as forum_api + + # Check if already banned + if forum_api.is_user_banned(user, course_key, check_org=(ban_scope == 'organization')): + existing_ban = forum_api.get_ban( + user=user, + course_id=course_key, + scope=ban_scope + ) + if existing_ban and existing_ban['is_active']: + return Response( + { + 'error': f'User {user.username} is already banned at {ban_scope} level', + 'ban_id': existing_ban['id'] + }, + status=status.HTTP_400_BAD_REQUEST + ) + + # Use forum API to ban user + ban_result = forum_api.ban_user( + user=user, + banned_by=request.user, + course_id=course_key, + scope=ban_scope, + reason=reason + ) + + # Determine action type and message + if ban_result.get('reactivated'): + action_type = 'ban_reactivate' + message = f'User {user.username} ban reactivated at {ban_scope} level' + else: + action_type = 'ban_user' + message = f'User {user.username} banned at {ban_scope} level' + + return ban_result, action_type, message + + def ban_user(self, request): + """ + Ban a user from discussions without deleting posts. + + **Use Cases** + + * Ban user directly from UI moderation interface + * Prevent future posts without removing existing content + * Apply preventive bans based on behavior patterns + + **Example Requests** + + POST /api/discussion/v1/moderation/ban-user/ + + Course-level ban: + ```json + { + "user_id": 12345, + "course_id": "course-v1:HarvardX+CS50+2024", + "scope": "course", + "reason": "Repeated policy violations" + } + ``` + + Organization-level ban (requires global staff): + ```json + { + "username": "spammer123", + "course_id": "course-v1:HarvardX+CS50+2024", + "scope": "organization", + "reason": "Spam across multiple courses" + } + ``` + + **Response Values** + + * status: Success status + * message: Human-readable message + * ban_id: ID of the created ban record + * user_id: Banned user's ID + * username: Banned user's username + * scope: Scope of the ban + * course_id: Course ID (if course-level ban) + + **Notes** + + * Creates ban without deleting existing posts + * Course-level bans require course moderation permissions + * Organization-level bans require global staff permissions + * Reactivates existing inactive bans if found + * All ban actions are logged in ModerationAuditLog + """ + from forum import api as forum_api + from lms.djangoapps.discussion.rest_api.serializers import BanUserRequestSerializer + + # Check if ban API is available + if not hasattr(forum_api, 'ban_user') or not hasattr(forum_api, 'is_user_banned'): + return Response( + {'error': 'Ban functionality is not available in this forum version'}, + status=status.HTTP_501_NOT_IMPLEMENTED + ) + + serializer = BanUserRequestSerializer(data=request.data, context={'request': request}) + if not serializer.is_valid(): + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + # Validate and get user + result = self._validate_ban_request_and_get_user(request, serializer.validated_data) + if isinstance(result, Response): + return result + user, course_key, ban_scope, reason = result + + # Check permissions + permission_error = self._check_ban_permissions(request, ban_scope, course_key) + if permission_error: + return permission_error + + # Get or create ban + result = self._get_or_create_ban(user, course_key, ban_scope, reason, request) + if isinstance(result, Response): + return result + ban, action_type, message = result + + # Audit log + org_key = course_key.org if ban_scope == 'organization' else None + forum_api.create_audit_log( + action_type=action_type, + target_user=user, + moderator=request.user, + course_id=str(course_key), + scope=ban_scope, + reason=reason or 'No reason provided', + metadata={ + 'ban_id': ban['id'], + 'organization': org_key + } + ) + + return Response({ + 'status': 'success', + 'message': message, + 'ban_id': ban['id'], + 'user_id': user.id, + 'username': user.username, + 'scope': ban_scope, + 'course_id': str(course_key) if ban_scope == 'course' else None, + }, status=status.HTTP_201_CREATED) + + @apidocs.schema( + body=openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['course_id', 'scope'], + properties={ + 'user_id': openapi.Schema( + type=openapi.TYPE_INTEGER, + description='ID of the user to unban (required if username is not provided)' + ), + 'username': openapi.Schema( + type=openapi.TYPE_STRING, + description='Username of the user to unban (required if user_id is not provided)' + ), + 'course_id': openapi.Schema( + type=openapi.TYPE_STRING, + description='Course ID (e.g., course-v1:edX+DemoX+Demo_Course)' + ), + 'scope': openapi.Schema( + type=openapi.TYPE_STRING, + description='Scope of ban to lift: "course" or "organization"', + enum=['course', 'organization'] + ), + 'reason': openapi.Schema( + type=openapi.TYPE_STRING, + description='Reason for unbanning', + max_length=1000 + ), + }, + ), + responses={ + 200: openapi.Response( + description='User unbanned successfully', + schema=openapi.Schema( + type=openapi.TYPE_OBJECT, + properties={ + 'status': openapi.Schema(type=openapi.TYPE_STRING, example='success'), + 'message': openapi.Schema(type=openapi.TYPE_STRING), + 'ban_id': openapi.Schema(type=openapi.TYPE_INTEGER), + 'user_id': openapi.Schema(type=openapi.TYPE_INTEGER), + 'username': openapi.Schema(type=openapi.TYPE_STRING), + 'scope': openapi.Schema(type=openapi.TYPE_STRING), + }, + ), + ), + 400: 'Invalid request data or user not currently banned.', + 401: 'The requester is not authenticated.', + 403: 'The requester does not have permission to unban users.', + 404: 'The specified user or ban does not exist.', + }, + ) + def unban_user(self, request): + """ + Unban a user from discussions. + + **Use Cases** + + * Lift ban after user appeal + * Remove accidental or temporary bans + * Restore discussion access + + **Example Requests** + + POST /api/discussion/v1/moderation/unban-user/ + + Course-level unban: + ```json + { + "user_id": 12345, + "course_id": "course-v1:HarvardX+CS50+2024", + "scope": "course", + "reason": "User appealed and corrected behavior" + } + ``` + + Organization-level unban: + ```json + { + "username": "student123", + "course_id": "course-v1:HarvardX+CS50+2024", + "scope": "organization", + "reason": "Ban lifted after review" + } + ``` + + **Response Values** + + * status: Success status + * message: Human-readable message + * ban_id: ID of the unbanned record + * user_id: Unbanned user's ID + * username: Unbanned user's username + * scope: Scope of the ban that was lifted + + **Notes** + + * Deactivates the ban without deleting the record + * Course-level unbans require course moderation permissions + * Organization-level unbans require global staff permissions + * All unban actions are logged in ModerationAuditLog + """ + from forum import api as forum_api + from lms.djangoapps.discussion.rest_api.serializers import BanUserRequestSerializer + + # Check if ban API is available + if not hasattr(forum_api, 'unban_user') or not hasattr(forum_api, 'is_user_banned'): + return Response( + {'error': 'Ban functionality is not available in this forum version'}, + status=status.HTTP_501_NOT_IMPLEMENTED + ) + + serializer = BanUserRequestSerializer(data=request.data, context={'request': request}) + if not serializer.is_valid(): + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + # Validate and get user + result = self._validate_ban_request_and_get_user(request, serializer.validated_data) + if isinstance(result, Response): + return result + + user, course_key, ban_scope, reason = result + + # Permission check + permission_error = self._check_ban_permissions(request, ban_scope, course_key) + if permission_error: + return permission_error + + # Check if user has an active ban + if not forum_api.is_user_banned(user, course_key, check_org=(ban_scope == 'organization')): + return Response( + { + 'error': f'User {user.username} does not have an active ban at {ban_scope} level', + }, + status=status.HTTP_400_BAD_REQUEST + ) + + # Get ban details before unbanning + ban_data = forum_api.get_ban( + user=user, + course_id=course_key, + scope=ban_scope + ) + + # Unban using forum API + # For org-level bans, pass course_id=None to fully unban across org + # (passing course_id for org ban creates an exception instead) + # NOTE: The newer /moderation/{pk}/unban/ endpoint (line ~2912) correctly + # supports optional course_id for creating exceptions. This older endpoint + # should always fully unban when scope='organization'. + unban_result = forum_api.unban_user( + user=user, + unbanned_by=request.user, + course_id=course_key if ban_scope == 'course' else None, + scope=ban_scope + ) + + # Prepare ban parameters based on scope + org_key = course_key.org if ban_scope == 'organization' else None + + # Audit log + forum_api.create_audit_log( + action_type='unban_user', + target_user=user, + moderator=request.user, + course_id=str(course_key), + scope=ban_scope, + reason=reason or 'No reason provided', + metadata={ + 'ban_id': ban_data.get('id') if ban_data else None, + 'organization': org_key + }, + ) + + return Response({ + 'status': 'success', + 'message': f'User {user.username} unbanned at {ban_scope} level', + 'ban_id': ban_data.get('id') if ban_data else None, + 'user_id': user.id, + 'username': user.username, + 'scope': ban_scope, + }, status=status.HTTP_200_OK) + + @apidocs.schema( + body=openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['user_id', 'course_id'], + properties={ + 'user_id': openapi.Schema( + type=openapi.TYPE_INTEGER, + description='ID of the user whose posts should be deleted' + ), + 'course_id': openapi.Schema( + type=openapi.TYPE_STRING, + description='Course ID (e.g., course-v1:edX+DemoX+Demo_Course)' + ), + 'ban_user': openapi.Schema( + type=openapi.TYPE_BOOLEAN, + description='If true, ban the user after deleting posts', + default=False + ), + 'ban_scope': openapi.Schema( + type=openapi.TYPE_STRING, + description='Scope of ban: "course" or "organization"', + enum=['course', 'organization'], + default='course' + ), + 'reason': openapi.Schema( + type=openapi.TYPE_STRING, + description='Reason for ban (required if ban_user is true)', + max_length=1000 + ), + }, + ), + responses={ + 202: openapi.Response( + description='Deletion task queued successfully', + schema=openapi.Schema( + type=openapi.TYPE_OBJECT, + properties={ + 'status': openapi.Schema(type=openapi.TYPE_STRING, example='success'), + 'message': openapi.Schema(type=openapi.TYPE_STRING), + 'task_id': openapi.Schema(type=openapi.TYPE_STRING), + }, + ), + ), + 400: 'Invalid request data or missing required parameters.', + 401: 'The requester is not authenticated.', + 403: 'The requester does not have permission to perform bulk delete.', + 404: 'The specified user does not exist.', + }, + ) + def bulk_delete_ban(self, request): + """ + Delete all user posts in a course and optionally ban the user. + + **Use Cases** + + * Remove all discussion content from a spam account + * Ban user from course or organization discussions + * Bulk cleanup of policy-violating content + + **Example Request** + + POST /api/discussion/v1/moderation/bulk-delete-ban/ + + ```json + { + "user_id": 12345, + "course_id": "course-v1:HarvardX+CS50+2024", + "ban_user": true, + "ban_scope": "course", + "reason": "Posting spam and scam content" + } + ``` + + **Response Values** + + * status: Success status of the request + * message: Human-readable message about the queued task + * task_id: Celery task ID for tracking the asynchronous operation + + **Notes** + + * This operation is asynchronous and returns a task ID + * If ban_user is true, a ban record will be created after content deletion + * Reason is required when ban_user is true + * Email notification is sent to partner-support upon ban + * Staff and privileged users cannot be banned + """ + from lms.djangoapps.discussion.rest_api.serializers import BulkDeleteBanRequestSerializer + from lms.djangoapps.discussion.rest_api.utils import _is_privileged_user + + serializer = BulkDeleteBanRequestSerializer(data=request.data, context={'request': request}) + if not serializer.is_valid(): + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + validated_data = serializer.validated_data + try: + course_key = CourseKey.from_string(validated_data['course_id']) + except InvalidKeyError: + return Response( + {'error': f"Invalid course_id: {validated_data['course_id']}"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + try: + target_user = User.objects.get(id=validated_data['user_id']) + except User.DoesNotExist: + return Response( + {'error': f'User with ID {validated_data["user_id"]} does not exist'}, + status=status.HTTP_404_NOT_FOUND + ) + + # Check if target user is staff/privileged - they shouldn't be banned + if validated_data['ban_user'] and _is_privileged_user(target_user, course_key): + return Response( + { + 'error': ( + f'Cannot ban staff or privileged users. User {target_user.username} ' + f'has elevated permissions in this course.' + ) + }, + status=status.HTTP_400_BAD_REQUEST + ) + + # Check if ban feature is enabled for this course + if validated_data['ban_user']: + if not ENABLE_DISCUSSION_BAN.is_enabled(course_key): + return Response( + {'error': 'Discussion ban feature is not enabled for this course'}, + status=status.HTTP_403_FORBIDDEN + ) + + # Enqueue Celery task (backward compatible with new parameters) + task = delete_course_post_for_user.apply_async( + kwargs={ + 'user_id': validated_data['user_id'], + 'username': target_user.username, + 'course_ids': [validated_data['course_id']], + 'ban_user': validated_data['ban_user'], + 'ban_scope': validated_data.get('ban_scope', 'course'), + 'moderator_id': request.user.id, + 'reason': validated_data.get('reason', ''), + } + ) + + message = ( + 'Deletion task queued. User will be banned upon completion.' + if validated_data['ban_user'] + else 'Deletion task queued.' + ) + return Response({ + 'status': 'success', + 'message': message, + 'task_id': task.id, + }, status=status.HTTP_202_ACCEPTED) + + @apidocs.schema( + parameters=[ + apidocs.string_parameter( + 'course_id', + apidocs.ParameterLocation.PATH, + description='Course ID to retrieve banned users for (required)' + ), + apidocs.string_parameter( + 'scope', + apidocs.ParameterLocation.QUERY, + description='Filter by ban scope: "course" or "organization"' + ), + ], + responses={ + 200: openapi.Response( + description='List of banned users', + schema=openapi.Schema( + type=openapi.TYPE_OBJECT, + properties={ + 'count': openapi.Schema( + type=openapi.TYPE_INTEGER, + description='Total number of banned users' + ), + 'results': openapi.Schema( + type=openapi.TYPE_ARRAY, + description='Array of banned user records', + items=openapi.Schema( + type=openapi.TYPE_OBJECT, + properties={ + 'id': openapi.Schema(type=openapi.TYPE_INTEGER), + 'username': openapi.Schema(type=openapi.TYPE_STRING), + 'email': openapi.Schema(type=openapi.TYPE_STRING), + 'user_id': openapi.Schema(type=openapi.TYPE_INTEGER), + 'course_id': openapi.Schema(type=openapi.TYPE_STRING), + 'organization': openapi.Schema(type=openapi.TYPE_STRING), + 'scope': openapi.Schema(type=openapi.TYPE_STRING), + 'reason': openapi.Schema(type=openapi.TYPE_STRING), + 'banned_at': openapi.Schema(type=openapi.TYPE_STRING, format='date-time'), + 'banned_by_username': openapi.Schema(type=openapi.TYPE_STRING), + 'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN), + }, + ), + ), + }, + ), + ), + 400: 'Missing required course_id parameter.', + 401: 'The requester is not authenticated.', + 403: 'The requester does not have permission to view banned users.', + }, + ) + def banned_users(self, request, course_id=None): + """ + Retrieve list of banned users for a specific course. + + **Use Cases** + + * View all currently banned users in a course + * Filter banned users by scope (course-level vs organization-level) + * Audit moderation actions + * Unban users who were mistakenly banned (including staff) + + **Example Requests** + + GET /api/discussion/v1/moderation/banned-users/course-v1:HarvardX+CS50+2024 + GET /api/discussion/v1/moderation/banned-users/course-v1:edX+DemoX+Demo?scope=course + + **Response Values** + + * count: Total number of active bans for the course + * results: Array of ban records with user information (deduplicated) + + **Notes** + + * Only returns active bans (is_active=True) + * Course-level bans are specific to one course + * Organization-level bans apply to all courses in the organization + * Shows ALL banned users including staff (so they can be unbanned if mistakenly banned) + * Deduplicates users with multiple ban records (e.g., course-level + org-level) + * New bans of staff are prevented by validation in ban endpoints + """ + from forum import api as forum_api + from lms.djangoapps.discussion.rest_api.permissions import can_take_action_on_spam + + if not course_id: + return Response( + {'error': 'course_id parameter is required'}, + status=status.HTTP_400_BAD_REQUEST + ) + + try: + course_key = CourseKey.from_string(course_id) + except InvalidKeyError: + return Response( + {'error': f'Invalid course_id: {course_id}'}, + status=status.HTTP_400_BAD_REQUEST, + ) + + # Permission check: user must be able to moderate in this course + if not can_take_action_on_spam(request.user, course_key): + return Response( + {'error': 'You do not have permission to view banned users in this course'}, + status=status.HTTP_403_FORBIDDEN + ) + + # Check if ban feature is enabled for this course + if not ENABLE_DISCUSSION_BAN.is_enabled(course_key): + return Response( + {'error': 'Discussion ban feature is not enabled for this course'}, + status=status.HTTP_403_FORBIDDEN + ) + + # Optional scope filter + scope = request.query_params.get('scope') + + # Get banned users using forum API + banned_users_data = forum_api.get_banned_users( + course_id=course_key, + scope=scope + ) + + # Deduplicate by user_id (user may have both course-level and org-level bans) + # Keep the first occurrence (most relevant ban record) + seen_user_ids = set() + deduplicated_banned_users = [] + for ban in banned_users_data: + user_id = ban.get('user', {}).get('id') + if user_id and user_id not in seen_user_ids: + seen_user_ids.add(user_id) + deduplicated_banned_users.append(ban) + + return Response({ + 'count': len(deduplicated_banned_users), + 'results': deduplicated_banned_users + }) + + @apidocs.schema( + parameters=[ + apidocs.string_parameter( + 'pk', + apidocs.ParameterLocation.PATH, + description='Ban ID to unban' + ), + ], + body=openapi.Schema( + type=openapi.TYPE_OBJECT, + properties={ + 'course_id': openapi.Schema( + type=openapi.TYPE_STRING, + description='Course ID for organization-level ban exceptions' + ), + 'reason': openapi.Schema( + type=openapi.TYPE_STRING, + description='Reason for unbanning' + ), + }, + required=['reason'], + ), + responses={ + 200: openapi.Response( + description='User unbanned successfully', + schema=openapi.Schema( + type=openapi.TYPE_OBJECT, + properties={ + 'status': openapi.Schema(type=openapi.TYPE_STRING, example='success'), + 'message': openapi.Schema(type=openapi.TYPE_STRING), + 'exception_created': openapi.Schema( + type=openapi.TYPE_BOOLEAN, + description='True if org-level ban exception was created' + ), + }, + ), + ), + 401: 'The requester is not authenticated.', + 403: 'The requester does not have permission to unban users.', + 404: 'Active ban not found with the specified ID.', + }, + ) + def unban_user_by_id(self, request, pk=None): + """ + Unban a user from discussions or create course-level exception (by ban ID). + + **Use Cases** + + * Lift a course-level ban completely + * Lift an organization-level ban completely + * Create course-specific exception to organization-level ban + * Process user appeals + + **Example Requests** + + POST /api/discussion/v1/moderation/123/unban/ + + ```json + { + "reason": "User appeal approved - first offense" + } + ``` + + Create exception for org-level ban: + + ```json + { + "course_id": "course-v1:HarvardX+CS50+2024", + "reason": "Exception approved for CS50 only" + } + ``` + + **Response Values** + + * status: Success status of the operation + * message: Human-readable message describing the action taken + * exception_created: Boolean indicating if an org-level exception was created + + **Notes** + + * For course-level bans: Deactivates the ban completely + * For org-level bans without course_id: Deactivates entire org-level ban + * For org-level bans with course_id: Creates exception allowing user in that course only + * All unban actions are logged in ModerationAuditLog + """ + from forum import api as forum_api + from lms.djangoapps.discussion.rest_api.permissions import can_take_action_on_spam + + # Get ban using forum API + try: + ban = forum_api.get_ban(pk) + except Exception: # pylint: disable=broad-exception-caught + return Response( + {'error': 'Active ban not found'}, + status=status.HTTP_404_NOT_FOUND + ) + + # Check if ban is active + if not ban.get('is_active'): + return Response( + {'error': 'Active ban not found'}, + status=status.HTTP_404_NOT_FOUND + ) + + course_id = request.data.get('course_id') + reason = request.data.get('reason', '').strip() + parsed_course_key = None + + if course_id: + try: + parsed_course_key = CourseKey.from_string(course_id) + except InvalidKeyError: + return Response( + {'error': f'Invalid course_id: {course_id}'}, + status=status.HTTP_400_BAD_REQUEST, + ) + + # Import dependencies + from common.djangoapps.student.roles import GlobalStaff + + # Permission check: depends on ban type and what user is trying to do + ban_course_id = ban.get('course_id') + if ban_course_id: + # Course-level ban - check permissions for that specific course + course_key_obj = CourseKey.from_string(ban_course_id) + if not can_take_action_on_spam(request.user, course_key_obj): + return Response( + {'error': 'You do not have permission to unban users in this course'}, + status=status.HTTP_403_FORBIDDEN + ) + else: + # Org-level ban + if course_id: + # Creating exception for specific course - check permissions in that course + if not can_take_action_on_spam(request.user, parsed_course_key): + return Response( + {'error': 'You do not have permission to create exceptions in this course'}, + status=status.HTTP_403_FORBIDDEN + ) + else: + # Fully unbanning org-level ban - only global staff can do this + if not (GlobalStaff().has_user(request.user) or request.user.is_staff): + return Response( + {'error': 'Only global staff can fully unban organization-level bans'}, + status=status.HTTP_403_FORBIDDEN + ) + + # Check if ban feature is enabled + # Determine which course_key to use for flag check + if ban_course_id: + # Course-level ban - use ban's course_id + course_key_for_flag = CourseKey.from_string(ban_course_id) + elif course_id: + # Org-level ban with course exception - use provided course_id + course_key_for_flag = parsed_course_key + elif ban.get('scope') == 'organization' and ban.get('org_key'): + # Org-level ban without course_id - find any course in org to check flag + from openedx.core.djangoapps.content.course_overviews.models import CourseOverview + try: + # Find any course in the organization to check the flag + org_course = CourseOverview.objects.filter(org=ban['org_key']).first() + if org_course: + course_key_for_flag = org_course.id + else: + # No courses found in org - deny unless global staff + if not (GlobalStaff().has_user(request.user) or request.user.is_staff): + return Response( + {'error': 'Discussion ban feature check requires course context or global staff access'}, + status=status.HTTP_403_FORBIDDEN + ) + # Global staff can proceed without flag check for org-level operations + course_key_for_flag = None + except Exception: # pylint: disable=broad-exception-caught + # Fallback: deny unless global staff + if not (GlobalStaff().has_user(request.user) or request.user.is_staff): + return Response( + {'error': 'Discussion ban feature check requires course context or global staff access'}, + status=status.HTTP_403_FORBIDDEN + ) + course_key_for_flag = None + else: + course_key_for_flag = None + + # Check flag if we have a course_key + if course_key_for_flag: + if not ENABLE_DISCUSSION_BAN.is_enabled(course_key_for_flag): + return Response( + {'error': 'Discussion ban feature is not enabled for this course'}, + status=status.HTTP_403_FORBIDDEN + ) + + # Validate that reason is provided + if not reason: + return Response( + {'error': 'reason field is required'}, + status=status.HTTP_400_BAD_REQUEST + ) + + # Use forum API to unban - it handles both full unban and exceptions + try: + unban_result = forum_api.unban_user( + ban_id=pk, + unbanned_by=request.user, + course_id=course_id, + reason=reason + ) + + return Response({ + 'status': unban_result.get('status', 'success'), + 'message': unban_result.get('message', 'User unbanned successfully'), + 'exception_created': unban_result.get('exception_created', False) + }) + except ValueError as e: + return Response( + {'error': str(e)}, + status=status.HTTP_404_NOT_FOUND + ) + except Exception as e: # pylint: disable=broad-exception-caught + log.error(f"Error unbanning user: {e}") + return Response( + {'error': 'An error occurred while unbanning the user'}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR + ) diff --git a/lms/djangoapps/discussion/templates/discussion/ban_escalation_email.txt b/lms/djangoapps/discussion/templates/discussion/ban_escalation_email.txt new file mode 100644 index 000000000000..f8a02b664b18 --- /dev/null +++ b/lms/djangoapps/discussion/templates/discussion/ban_escalation_email.txt @@ -0,0 +1,28 @@ +DISCUSSION BAN ALERT +================================================================================ + +A user has been banned from course discussions. + +Banned User: {{ banned_username }} ({{ banned_email }}) +User ID: {{ banned_user_id }} + +Moderator: {{ moderator_username }} ({{ moderator_email }}) +Moderator ID: {{ moderator_id }} + +Course ID: {{ course_id }} +Ban Scope: {{ scope|upper }} + +Reason: {{ reason }} + +Content Deleted: +- Threads: {{ threads_deleted }} +- Comments: {{ comments_deleted }} +- Total: {{ total_deleted }} + +================================================================================ + +ACTION REQUIRED: +Please review this moderation action and follow up as needed. If this ban was +applied in error or requires adjustment, contact the moderator or course staff. + +================================================================================ diff --git a/lms/djangoapps/discussion/templates/discussion/edx_ace/ban_escalation/email/body.html b/lms/djangoapps/discussion/templates/discussion/edx_ace/ban_escalation/email/body.html new file mode 100644 index 000000000000..d6a57962433f --- /dev/null +++ b/lms/djangoapps/discussion/templates/discussion/edx_ace/ban_escalation/email/body.html @@ -0,0 +1,82 @@ +{% extends 'ace_common/edx_ace/common/base_body.html' %} + +{% load i18n %} +{% load django_markup %} +{% load static %} + +{% block content %} + + + + +
+

+ {% filter force_escape %} + {% blocktrans %} + Discussion Ban Alert + {% endblocktrans %} + {% endfilter %} +

+ +

+ {% filter force_escape %} + {% blocktrans %} + A user has been banned from course discussions. Please review this moderation action. + {% endblocktrans %} + {% endfilter %} +

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Banned User{{ banned_username }} ({{ banned_email }})
User ID{{ banned_user_id }}
Moderator{{ moderator_username }} ({{ moderator_email }})
Moderator ID{{ moderator_id }}
Course ID{{ course_id }}
Ban Scope + {{ scope|upper }}{% if scope == 'organization' %} (All courses in organization){% endif %} +
Reason{{ reason }}
Content Deleted + {{ threads_deleted }} thread{{ threads_deleted|pluralize }}, + {{ comments_deleted }} comment{{ comments_deleted|pluralize }} +
+ Total: {{ total_deleted }} item{{ total_deleted|pluralize }} +
+ +

+ {% trans "Action Required:" as action_required %}{{ action_required|force_escape }} + {% trans "Please review this moderation action and follow up as needed. If this ban was applied in error or requires adjustment, contact the moderator or course staff." as review_instructions %}{{ review_instructions|force_escape }} +

+ + {% block google_analytics_pixel %} + {% if ga_tracking_pixel_url %} + + {% endif %} + {% endblock %} +
+{% endblock %} diff --git a/lms/djangoapps/discussion/templates/discussion/edx_ace/ban_escalation/email/body.txt b/lms/djangoapps/discussion/templates/discussion/edx_ace/ban_escalation/email/body.txt new file mode 100644 index 000000000000..f8a02b664b18 --- /dev/null +++ b/lms/djangoapps/discussion/templates/discussion/edx_ace/ban_escalation/email/body.txt @@ -0,0 +1,28 @@ +DISCUSSION BAN ALERT +================================================================================ + +A user has been banned from course discussions. + +Banned User: {{ banned_username }} ({{ banned_email }}) +User ID: {{ banned_user_id }} + +Moderator: {{ moderator_username }} ({{ moderator_email }}) +Moderator ID: {{ moderator_id }} + +Course ID: {{ course_id }} +Ban Scope: {{ scope|upper }} + +Reason: {{ reason }} + +Content Deleted: +- Threads: {{ threads_deleted }} +- Comments: {{ comments_deleted }} +- Total: {{ total_deleted }} + +================================================================================ + +ACTION REQUIRED: +Please review this moderation action and follow up as needed. If this ban was +applied in error or requires adjustment, contact the moderator or course staff. + +================================================================================ diff --git a/lms/djangoapps/discussion/templates/discussion/edx_ace/ban_escalation/email/from_name.txt b/lms/djangoapps/discussion/templates/discussion/edx_ace/ban_escalation/email/from_name.txt new file mode 100644 index 000000000000..fb090bda4e0e --- /dev/null +++ b/lms/djangoapps/discussion/templates/discussion/edx_ace/ban_escalation/email/from_name.txt @@ -0,0 +1 @@ +edX Discussion Moderation \ No newline at end of file diff --git a/lms/djangoapps/discussion/templates/discussion/edx_ace/ban_escalation/email/head.html b/lms/djangoapps/discussion/templates/discussion/edx_ace/ban_escalation/email/head.html new file mode 100644 index 000000000000..366ada7ad92e --- /dev/null +++ b/lms/djangoapps/discussion/templates/discussion/edx_ace/ban_escalation/email/head.html @@ -0,0 +1 @@ +{% extends 'ace_common/edx_ace/common/base_head.html' %} diff --git a/lms/djangoapps/discussion/templates/discussion/edx_ace/ban_escalation/email/subject.txt b/lms/djangoapps/discussion/templates/discussion/edx_ace/ban_escalation/email/subject.txt new file mode 100644 index 000000000000..a3e4a972368b --- /dev/null +++ b/lms/djangoapps/discussion/templates/discussion/edx_ace/ban_escalation/email/subject.txt @@ -0,0 +1 @@ +Discussion Moderation Alert: User Banned \ No newline at end of file diff --git a/lms/djangoapps/discussion/toggles.py b/lms/djangoapps/discussion/toggles.py index 61286686b759..9de8d0250555 100644 --- a/lms/djangoapps/discussion/toggles.py +++ b/lms/djangoapps/discussion/toggles.py @@ -37,3 +37,20 @@ # .. toggle_creation_date: 2025-07-29 # .. toggle_target_removal_date: 2026-07-29 ENABLE_RATE_LIMIT_IN_DISCUSSION = CourseWaffleFlag(f'{WAFFLE_FLAG_NAMESPACE}.enable_rate_limit', __name__) + + +# .. toggle_name: discussions.enable_discussion_ban +# .. toggle_implementation: CourseWaffleFlag +# .. toggle_default: False +# .. toggle_description: Waffle flag to enable ban user functionality in discussion moderation. +# When enabled, moderators can ban users from discussions at course or organization level +# during bulk delete operations. This addresses crypto spam attacks and harassment. +# .. toggle_use_cases: opt_in +# .. toggle_creation_date: 2024-11-24 +# .. toggle_target_removal_date: 2025-06-01 +# .. toggle_warning: This feature requires proper moderator training to prevent misuse. +# Ensure DISCUSSION_MODERATION_BAN_EMAIL_ENABLED is configured appropriately for your environment. +# .. toggle_tickets: COSMO2-736 +ENABLE_DISCUSSION_BAN = CourseWaffleFlag( + f'{WAFFLE_FLAG_NAMESPACE}.enable_discussion_ban', __name__ +) diff --git a/lms/djangoapps/discussion/views.py b/lms/djangoapps/discussion/views.py index c01fb24b073a..9bbed672249a 100644 --- a/lms/djangoapps/discussion/views.py +++ b/lms/djangoapps/discussion/views.py @@ -166,6 +166,7 @@ def get_threads(request, course, user_info, discussion_id=None, per_page=THREADS 'flagged', 'unread', 'unanswered', + 'context', ] ) ) @@ -187,6 +188,12 @@ def get_threads(request, course, user_info, discussion_id=None, per_page=THREADS if 'pinned' not in thread: thread['pinned'] = False + # Filter team discussions - only team members can see team posts + if discussion_id is not None and not is_privileged_user(course.id, request.user): + team = team_api.get_team_by_discussion(discussion_id) + if team and not team.users.filter(id=request.user.id).exists(): + threads = [] + query_params['page'] = paginated_results.page query_params['num_pages'] = paginated_results.num_pages query_params['corrected_text'] = paginated_results.corrected_text diff --git a/lms/djangoapps/grades/subsection_grade.py b/lms/djangoapps/grades/subsection_grade.py index 4ce0a1f3a463..b0c98497b823 100644 --- a/lms/djangoapps/grades/subsection_grade.py +++ b/lms/djangoapps/grades/subsection_grade.py @@ -5,8 +5,8 @@ from abc import ABCMeta from collections import OrderedDict +from datetime import datetime, timezone from logging import getLogger - from lazy import lazy from lms.djangoapps.grades.models import BlockRecord, PersistentSubsectionGrade @@ -59,6 +59,13 @@ def show_grades(self, has_staff_access): """ Returns whether subsection scores are currently available to users with or without staff access. """ + if self.show_correctness == ShowCorrectness.NEVER_BUT_INCLUDE_GRADE: + # show_grades fn is used to determine if the grade should be included in final calculation. + # For NEVER_BUT_INCLUDE_GRADE, show_grades returns True if the due date has passed, + # but correctness_available always returns False as we do not want to show correctness + # of problems to the users. + return (self.due is None or + self.due < datetime.now(timezone.utc)) return ShowCorrectness.correctness_available(self.show_correctness, self.due, has_staff_access) @property diff --git a/lms/djangoapps/grades/tests/test_course_grade_factory.py b/lms/djangoapps/grades/tests/test_course_grade_factory.py index d7d3a20c1ff0..b6865d225fde 100644 --- a/lms/djangoapps/grades/tests/test_course_grade_factory.py +++ b/lms/djangoapps/grades/tests/test_course_grade_factory.py @@ -71,28 +71,28 @@ def _assert_section_order(course_grade): with self.assertNumQueries(3), mock_get_score(1, 2): _assert_read(expected_pass=False, expected_percent=0) # start off with grade of 0 - num_queries = 42 + num_queries = 44 with self.assertNumQueries(num_queries), mock_get_score(1, 2): grade_factory.update(self.request.user, self.course, force_update_subsections=True) with self.assertNumQueries(3): _assert_read(expected_pass=True, expected_percent=0.5) # updated to grade of .5 - num_queries = 6 + num_queries = 8 with self.assertNumQueries(num_queries), mock_get_score(1, 4): grade_factory.update(self.request.user, self.course, force_update_subsections=False) with self.assertNumQueries(3): _assert_read(expected_pass=True, expected_percent=0.5) # NOT updated to grade of .25 - num_queries = 18 + num_queries = 20 with self.assertNumQueries(num_queries), mock_get_score(2, 2): grade_factory.update(self.request.user, self.course, force_update_subsections=True) with self.assertNumQueries(3): _assert_read(expected_pass=True, expected_percent=1.0) # updated to grade of 1.0 - num_queries = 28 + num_queries = 30 with self.assertNumQueries(num_queries), mock_get_score(0, 0): # the subsection now is worth zero grade_factory.update(self.request.user, self.course, force_update_subsections=True) diff --git a/lms/djangoapps/instructor_task/api.py b/lms/djangoapps/instructor_task/api.py index 6474efc1d374..4cb9c53f4d9c 100644 --- a/lms/djangoapps/instructor_task/api.py +++ b/lms/djangoapps/instructor_task/api.py @@ -52,6 +52,7 @@ generate_anonymous_ids_for_course ) from xmodule.modulestore.django import modulestore # lint-amnesty, pylint: disable=wrong-import-order +from django.db.models import Q log = logging.getLogger(__name__) @@ -82,12 +83,31 @@ def get_instructor_task_history(course_id, usage_key=None, student=None, task_ty that optionally match a particular problem, a student, and/or a task type. """ instructor_tasks = InstructorTask.objects.filter(course_id=course_id) + if usage_key is not None or student is not None: _, task_key = encode_problem_and_student_input(usage_key, student) instructor_tasks = instructor_tasks.filter(task_key=task_key) if task_type is not None: instructor_tasks = instructor_tasks.filter(task_type=task_type) + # Bulk email history is user-facing; only show tasks that represent + # real delivered emails (SUCCESS with succeeded > 0) or future scheduled sends. + if task_type == InstructorTaskTypes.BULK_COURSE_EMAIL: + instructor_tasks = instructor_tasks.filter( + # SUCCESS tasks must have delivery results, while SCHEDULED tasks + # have no task_output yet and must be included explicitly. + Q( + task_state='SUCCESS', + task_output__contains='"succeeded":' + ) | + Q( + task_state='SCHEDULED' + ) + ).exclude( + # Exclude completed tasks where no emails were actually sent + task_output__contains='"succeeded": 0' + ) + return instructor_tasks.order_by('-id') diff --git a/lms/djangoapps/instructor_task/tests/test_get_instructor_task_history.py b/lms/djangoapps/instructor_task/tests/test_get_instructor_task_history.py new file mode 100644 index 000000000000..15be7adf9b51 --- /dev/null +++ b/lms/djangoapps/instructor_task/tests/test_get_instructor_task_history.py @@ -0,0 +1,203 @@ +""" +Tests for get_instructor_task_history in bulk email. +""" +import json +from celery.states import SUCCESS, FAILURE, REVOKED + +from lms.djangoapps.instructor_task.api import get_instructor_task_history +from lms.djangoapps.instructor_task.tests.test_base import InstructorTaskCourseTestCase +from lms.djangoapps.instructor_task.tests.factories import InstructorTaskFactory + + +class TestGetInstructorTaskHistory(InstructorTaskCourseTestCase): + """ + Tests for updated filtering logic in get_instructor_task_history + + Rules: + - SUCCESS tasks must contain succeeded > 0 in task_output + - SCHEDULED tasks must be included even if task_output is empty + - SUCCESS tasks with succeeded = 0 must be excluded + - FAILED / REVOKED tasks must be excluded + """ + + def setUp(self): + super().setUp() + self.initialize_course() + self.instructor = self.create_instructor('instructor') + + def test_includes_successful_bulk_email_task(self): + """ + SUCCESS + succeeded > 0 → INCLUDED + """ + task_output = json.dumps({ + "attempted": 10, + "succeeded": 10, + "failed": 0 + }) + + success_task = InstructorTaskFactory.create( + course_id=self.course.id, + task_type="bulk_course_email", + task_id="bulk_email_success", + task_input='{}', + task_state=SUCCESS, + task_output=task_output, + task_key='bulk_email_success', + requester=self.instructor + ) + + tasks = list(get_instructor_task_history( + self.course.id, + task_type="bulk_course_email" + )) + + assert success_task in tasks + + def test_includes_scheduled_task_with_empty_output(self): + """ + SCHEDULED (even with empty {}) → INCLUDED + """ + scheduled_task = InstructorTaskFactory.create( + course_id=self.course.id, + task_type="bulk_course_email", + task_id="bulk_email_scheduled", + task_input='{}', + task_state="SCHEDULED", + task_output="{}", + task_key='bulk_email_scheduled', + requester=self.instructor + ) + + tasks = list(get_instructor_task_history( + self.course.id, + task_type="bulk_course_email" + )) + + assert scheduled_task in tasks + + def test_excludes_zero_success_tasks(self): + """ + SUCCESS + succeeded = 0 → EXCLUDED + """ + zero_success_task = InstructorTaskFactory.create( + course_id=self.course.id, + task_type="bulk_course_email", + task_id="bulk_email_zero", + task_state=SUCCESS, + task_output=json.dumps({ + "attempted": 10, + "succeeded": 0, + "failed": 10 + }), + task_key='bulk_email_zero', + requester=self.instructor + ) + + tasks = list(get_instructor_task_history( + self.course.id, + task_type="bulk_course_email" + )) + + assert zero_success_task not in tasks + + def test_excludes_failed_tasks(self): + """ + FAILURE → EXCLUDED + """ + failed_task = InstructorTaskFactory.create( + course_id=self.course.id, + task_type="bulk_course_email", + task_id="bulk_email_failed", + task_state=FAILURE, + task_output=json.dumps({ + "attempted": 5, + "succeeded": 0, + "failed": 5 + }), + task_key='bulk_email_failed', + requester=self.instructor + ) + + tasks = list(get_instructor_task_history( + self.course.id, + task_type="bulk_course_email" + )) + + assert failed_task not in tasks + + def test_excludes_revoked_tasks(self): + """ + REVOKED → EXCLUDED + """ + revoked_task = InstructorTaskFactory.create( + course_id=self.course.id, + task_type="bulk_course_email", + task_id="bulk_email_revoked", + task_state=REVOKED, + task_output='{"message": "Task revoked"}', + task_key='bulk_email_revoked', + requester=self.instructor + ) + + tasks = list(get_instructor_task_history( + self.course.id, + task_type="bulk_course_email" + )) + + assert revoked_task not in tasks + + def test_only_valid_tasks_returned(self): + """ + Only the following should be returned: + - SUCCESS with succeeded > 0 + - SCHEDULED + + Everything else must be excluded. + """ + valid_success = InstructorTaskFactory.create( + course_id=self.course.id, + task_type="bulk_course_email", + task_id="bulk_email_valid", + task_state=SUCCESS, + task_output=json.dumps({ + "attempted": 8, + "succeeded": 5, + "failed": 3 + }), + task_key='bulk_email_valid', + requester=self.instructor + ) + + scheduled = InstructorTaskFactory.create( + course_id=self.course.id, + task_type="bulk_course_email", + task_id="bulk_email_scheduled_2", + task_state="SCHEDULED", + task_output="{}", + task_key='bulk_email_scheduled_2', + requester=self.instructor + ) + + zero_task = InstructorTaskFactory.create( + course_id=self.course.id, + task_type="bulk_course_email", + task_id="bulk_email_zero_2", + task_state=SUCCESS, + task_output=json.dumps({ + "attempted": 5, + "succeeded": 0, + "failed": 5 + }), + task_key='bulk_email_zero_2', + requester=self.instructor + ) + + tasks = list(get_instructor_task_history( + self.course.id, + task_type="bulk_course_email" + )) + + assert valid_success in tasks + assert scheduled in tasks + assert zero_task not in tasks + assert len(tasks) == 2 diff --git a/lms/djangoapps/staticbook/tests.py b/lms/djangoapps/staticbook/tests.py index 919ee16c4fef..1d42ac4b5bcc 100644 --- a/lms/djangoapps/staticbook/tests.py +++ b/lms/djangoapps/staticbook/tests.py @@ -129,8 +129,11 @@ def test_book(self): url = self.make_url('pdf_book', book_index=0) response = self.client.get(url) self.assertContains(response, "Chapter 1 for PDF") - self.assertNotContains(response, "options.chapterNum =") - self.assertNotContains(response, "page=") + # Verify file parameter is not present (security fix) + self.assertNotContains(response, "file=") + # Verify postMessage infrastructure is in place + self.assertContains(response, "request_pdf_url") + self.assertContains(response, "pdf_url_response") def test_book_chapter(self): # We can access a book at a particular chapter. @@ -138,8 +141,10 @@ def test_book_chapter(self): url = self.make_url('pdf_book', book_index=0, chapter=2) response = self.client.get(url) self.assertContains(response, "Chapter 2 for PDF") - self.assertContains(response, "file={}".format(PDF_BOOK['chapters'][1]['url'])) - self.assertNotContains(response, "page=") + # Verify file parameter is not present anywhere (security fix) + self.assertNotContains(response, "file=") + # Verify postMessage infrastructure is in place + self.assertContains(response, "request_pdf_url") def test_book_page(self): # We can access a book at a particular page. @@ -147,7 +152,9 @@ def test_book_page(self): url = self.make_url('pdf_book', book_index=0, page=17) response = self.client.get(url) self.assertContains(response, "Chapter 1 for PDF") - self.assertNotContains(response, "options.chapterNum =") + # Verify file parameter is not present (security fix) + self.assertNotContains(response, "file=") + # Page parameter is still used in viewer_params self.assertContains(response, "page=17") def test_book_chapter_page(self): @@ -156,7 +163,9 @@ def test_book_chapter_page(self): url = self.make_url('pdf_book', book_index=0, chapter=2, page=17) response = self.client.get(url) self.assertContains(response, "Chapter 2 for PDF") - self.assertContains(response, "file={}".format(PDF_BOOK['chapters'][1]['url'])) + # Verify file parameter is not present (security fix) + self.assertNotContains(response, "file=") + # Page parameter is still used in viewer_params self.assertContains(response, "page=17") def test_bad_book_id(self): @@ -202,29 +211,32 @@ def test_chapter_page_xss(self): def test_static_url_map_contentstore(self): """ - This ensure static URL mapping is happening properly for - a course that uses the contentstore + This ensure static URL mapping is happening properly for + a course that uses the contentstore. + URLs are remapped in backend but not exposed via file parameter (security fix). """ self.make_course(pdf_textbooks=[PORTABLE_PDF_BOOK]) url = self.make_url('pdf_book', book_index=0, chapter=1) response = self.client.get(url) - self.assertNotContains(response, 'file={}'.format(PORTABLE_PDF_BOOK['chapters'][0]['url'])) - self.assertContains(response, 'file=/asset-v1:{0.org}+{0.course}+{0.run}+type@asset+block/{1}'.format( + # Verify file parameter is not present in response (security fix) + self.assertNotContains(response, 'file=') + # Verify the chapter URL is in the sidebar for postMessage communication + self.assertContains(response, '/asset-v1:{0.org}+{0.course}+{0.run}+type@asset+block/{1}'.format( self.course.location, PORTABLE_PDF_BOOK['chapters'][0]['url'].replace('/static/', ''))) def test_static_url_map_static_asset_path(self): """ - Like above, but used when the course has set a static_asset_path + Like above, but used when the course has set a static_asset_path. + URLs are remapped in backend but not exposed via file parameter (security fix). """ self.make_course(pdf_textbooks=[PORTABLE_PDF_BOOK], static_asset_path='awesomesauce') url = self.make_url('pdf_book', book_index=0, chapter=1) response = self.client.get(url) - self.assertNotContains(response, 'file={}'.format(PORTABLE_PDF_BOOK['chapters'][0]['url'])) - self.assertNotContains(response, 'file=/c4x/{0.org}/{0.course}/asset/{1}'.format( - self.course.location, - PORTABLE_PDF_BOOK['chapters'][0]['url'].replace('/static/', ''))) - self.assertContains(response, 'file=/static/awesomesauce/{}'.format( + # Verify file parameter is not present anywhere (security fix) + self.assertNotContains(response, 'file=') + # Verify the remapped URL is in the sidebar for postMessage communication + self.assertContains(response, '/static/awesomesauce/{}'.format( PORTABLE_PDF_BOOK['chapters'][0]['url'].replace('/static/', ''))) def test_invalid_chapter_id(self): diff --git a/lms/djangoapps/staticbook/views.py b/lms/djangoapps/staticbook/views.py index 2e6d1c9a8ed5..5c32f35f3b48 100644 --- a/lms/djangoapps/staticbook/views.py +++ b/lms/djangoapps/staticbook/views.py @@ -86,12 +86,11 @@ def pdf_index(request, course_id, book_index, chapter=None, page=None): raise Http404(f"Invalid book index value: {book_index}") textbook = course.pdf_textbooks[book_index] - viewer_params = '&file=' + viewer_params = '' current_url = '' if 'url' in textbook: textbook['url'] = remap_static_url(textbook['url'], course) - viewer_params += textbook['url'] current_url = textbook['url'] # then remap all the chapter URLs as well, if they are provided. @@ -99,18 +98,24 @@ def pdf_index(request, course_id, book_index, chapter=None, page=None): if 'chapters' in textbook: for entry in textbook['chapters']: entry['url'] = remap_static_url(entry['url'], course) + # Security: Validate chapter URL doesn't contain dangerous schemes + if entry['url'].lower().startswith(('javascript:', 'data:', 'vbscript:', 'file:')): + entry['url'] = '' # Sanitize dangerous URLs if chapter is not None and int(chapter) <= (len(textbook['chapters'])): current_chapter = textbook['chapters'][int(chapter) - 1] else: current_chapter = textbook['chapters'][0] - viewer_params += current_chapter['url'] + current_url = current_chapter['url'] viewer_params += '#zoom=page-fit&disableRange=true' if page is not None: viewer_params += f'&page={page}' - if request.GET.get('viewer', '') == 'true': + if current_url.startswith('https://'): + current_url = '' + template = 'static_pdfbook.html' + elif request.GET.get('viewer', '') == 'true': template = 'pdf_viewer.html' else: template = 'static_pdfbook.html' diff --git a/lms/envs/common.py b/lms/envs/common.py index 3dde7156b93e..cb6cc78e9fb1 100644 --- a/lms/envs/common.py +++ b/lms/envs/common.py @@ -3489,6 +3489,36 @@ AVAILABLE_DISCUSSION_TOURS = [] +############## DISCUSSION MODERATION ############## + +# .. toggle_name: settings.DISCUSSION_MODERATION_BAN_EMAIL_ENABLED +# .. toggle_implementation: DjangoSetting +# .. toggle_default: True +# .. toggle_description: Enable/disable email notifications when users are banned from discussions. +# Set to False in development/test environments to prevent spam to partner-support@edx.org. +# When enabled, escalation emails are sent to DISCUSSION_MODERATION_ESCALATION_EMAIL address. +# .. toggle_use_cases: opt_in +# .. toggle_creation_date: 2024-11-24 +# .. toggle_tickets: COSMO2-736 +DISCUSSION_MODERATION_BAN_EMAIL_ENABLED = True + +# .. setting_name: DISCUSSION_MODERATION_ESCALATION_EMAIL +# .. setting_default: 'partner-support@edx.org' +# .. setting_description: Email address to receive ban escalation notifications when users are banned +# from discussions. Override in development to use a test email address. +# .. setting_use_cases: opt_in +# .. setting_creation_date: 2024-11-24 +# .. setting_tickets: COSMO2-736 +DISCUSSION_MODERATION_ESCALATION_EMAIL = 'partner-support@edx.org' + +# .. setting_name: DISCUSSION_MODERATION_BAN_REASON_MAX_LENGTH +# .. setting_default: 1000 +# .. setting_description: Maximum character length for ban reason text. +# .. setting_use_cases: opt_in +# .. setting_creation_date: 2024-11-24 +# .. setting_tickets: COSMO2-736 +DISCUSSION_MODERATION_BAN_REASON_MAX_LENGTH = 1000 + ############## NOTIFICATIONS ############## NOTIFICATION_TYPE_ICONS = {} DEFAULT_NOTIFICATION_ICON_URL = "" diff --git a/lms/envs/devstack.py b/lms/envs/devstack.py index b15533855c7b..b20db319ff37 100644 --- a/lms/envs/devstack.py +++ b/lms/envs/devstack.py @@ -40,6 +40,10 @@ CLEAR_REQUEST_CACHE_ON_TASK_COMPLETION = False HTTPS = 'off' +# Disable ban emails in local development to prevent spam +DISCUSSION_MODERATION_BAN_EMAIL_ENABLED = False +DISCUSSION_MODERATION_ESCALATION_EMAIL = 'devnull@example.com' + LMS_ROOT_URL = f'http://{LMS_BASE}' LMS_INTERNAL_ROOT_URL = LMS_ROOT_URL ENTERPRISE_API_URL = f'{LMS_INTERNAL_ROOT_URL}/enterprise/api/v1/' diff --git a/lms/envs/test.py b/lms/envs/test.py index 218a7e8461b8..67b56a954404 100644 --- a/lms/envs/test.py +++ b/lms/envs/test.py @@ -72,6 +72,10 @@ # the one in cms/envs/test.py ENABLE_DISCUSSION_SERVICE = False +# Disable ban emails in tests to prevent spam and speed up tests +DISCUSSION_MODERATION_BAN_EMAIL_ENABLED = False +DISCUSSION_MODERATION_ESCALATION_EMAIL = 'test@example.com' + ENABLE_SERVICE_STATUS = True ENABLE_VERIFIED_CERTIFICATES = True diff --git a/lms/static/js/student_account/views/RegisterView.js b/lms/static/js/student_account/views/RegisterView.js index 42ab7c8857a8..bd958c2d8bd4 100644 --- a/lms/static/js/student_account/views/RegisterView.js +++ b/lms/static/js/student_account/views/RegisterView.js @@ -58,6 +58,7 @@ ); this.currentProvider = data.thirdPartyAuth.currentProvider || ''; this.syncLearnerProfileData = data.thirdPartyAuth.syncLearnerProfileData || false; + this.skipRegistrationOptionalCheckboxes = data.thirdPartyAuth.skipRegistrationOptionalCheckboxes || false; this.errorMessage = data.thirdPartyAuth.errorMessage || ''; this.platformName = data.platformName; this.autoSubmit = data.thirdPartyAuth.autoSubmitRegForm; @@ -156,6 +157,7 @@ fields: fields, currentProvider: this.currentProvider, syncLearnerProfileData: this.syncLearnerProfileData, + skipRegistrationOptionalCheckboxes: this.skipRegistrationOptionalCheckboxes, providers: this.providers, hasSecondaryProviders: this.hasSecondaryProviders, platformName: this.platformName, diff --git a/lms/templates/courseware/progress.html b/lms/templates/courseware/progress.html index 711fad895427..3ee4044fcbbf 100644 --- a/lms/templates/courseware/progress.html +++ b/lms/templates/courseware/progress.html @@ -16,6 +16,7 @@ from lms.djangoapps.grades.api import constants as grades_constants from openedx.core.djangolib.markup import HTML, Text from openedx.features.enterprise_support.utils import get_enterprise_learner_generic_name +from xmodule.graders import ShowCorrectness %> <% @@ -180,7 +181,7 @@

${ chapter['display_name']}

%if hide_url:

${section.display_name} - %if (total > 0 or earned > 0) and section.show_grades(staff_access): + %if (total > 0 or earned > 0) and ShowCorrectness.correctness_available(section.show_correctness, section.due, staff_access): ${_("{earned} of {total} possible points").format(earned='{:.3n}'.format(float(earned)), total='{:.3n}'.format(float(total)))} @@ -189,14 +190,14 @@

%else: ${ section.display_name} - %if (total > 0 or earned > 0) and section.show_grades(staff_access): + %if (total > 0 or earned > 0) and ShowCorrectness.correctness_available(section.show_correctness, section.due, staff_access): ${_("{earned} of {total} possible points").format(earned='{:.3n}'.format(float(earned)), total='{:.3n}'.format(float(total)))} %endif %endif - %if (total > 0 or earned > 0) and section.show_grades(staff_access): + %if (total > 0 or earned > 0) and ShowCorrectness.correctness_available(section.show_correctness, section.due, staff_access): ${"({0:.3n}/{1:.3n}) {2}".format( float(earned), float(total), percentageString )} %endif

@@ -219,7 +220,7 @@

%endif

%if len(section.problem_scores.values()) > 0: - %if section.show_grades(staff_access): + %if ShowCorrectness.correctness_available(section.show_correctness, section.due, staff_access):
${ _("Problem Scores: ") if section.graded else _("Practice Scores: ")}
%for score in section.problem_scores.values(): diff --git a/lms/templates/lti.html b/lms/templates/lti.html index 05346ec3dc40..d03b7e386740 100644 --- a/lms/templates/lti.html +++ b/lms/templates/lti.html @@ -37,6 +37,7 @@

% else: diff --git a/lms/templates/pdf_viewer.html b/lms/templates/pdf_viewer.html index a3314d54b5ac..1f4b8cd54abf 100644 --- a/lms/templates/pdf_viewer.html +++ b/lms/templates/pdf_viewer.html @@ -1,5 +1,5 @@ <%page expression_filter="h"/> - + <%namespace name='static' file='static_content.html'/> <%! from openedx.core.djangolib.js_utils import ( @@ -46,10 +46,44 @@ PDFJS.workerSrc = "${static.url('js/vendor/pdfjs/pdf.worker.js') | n, js_escaped_string}"; PDFJS.disableWorker = true; PDFJS.cMapUrl = "${static.url('css/vendor/pdfjs/cmaps/') | n, js_escaped_string}"; - PDF_URL = '${current_url | n, js_escaped_string}'; + + var PDF_URL = '${current_url | n, js_escaped_string}'; + + if (window.parent !== window) { + window.parent.postMessage({type: 'request_pdf_url'}, '*'); + + function handlePdfUrlResponse(event) { + if (event.data && event.data.type === 'pdf_url_response') { + PDF_URL = event.data.url; + + if (PDFViewerApplication.open) { + PDFViewerApplication.open(PDF_URL); + PDFViewerApplication.mouseScroll(0); + + setTimeout(function() { + if (PDFViewerApplication.pdfDocument) { + if (event.data.title) document.getElementById('titleField').textContent = event.data.title; + if (event.data.author) document.getElementById('authorField').textContent = event.data.author; + if (event.data.subject) document.getElementById('subjectField').textContent = event.data.subject; + if (event.data.keywords) document.getElementById('keywordsField').textContent = event.data.keywords; + document.getElementById('creatorField').textContent = 'edX Platform'; + } + }, 500); + } + } else if (event.data && event.data.type === 'chapter_change') { + window.parent.postMessage({type: 'request_pdf_url'}, '*'); + } + } + + window.addEventListener('message', handlePdfUrlResponse); + } + + document.addEventListener('DOMContentLoaded', function () { + PDFViewerApplication && PDFViewerApplication.open(PDF_URL); + }); - + <%static:js group='main_vendor'/> <%static:js group='application'/> @@ -347,77 +381,77 @@
-