diff --git a/.github/workflows/add-depr-ticket-to-depr-board.yml b/.github/workflows/add-depr-ticket-to-depr-board.yml
deleted file mode 100644
index 250e394abc11..000000000000
--- a/.github/workflows/add-depr-ticket-to-depr-board.yml
+++ /dev/null
@@ -1,19 +0,0 @@
-# Run the workflow that adds new tickets that are either:
-# - labelled "DEPR"
-# - title starts with "[DEPR]"
-# - body starts with "Proposal Date" (this is the first template field)
-# to the org-wide DEPR project board
-
-name: Add newly created DEPR issues to the DEPR project board
-
-on:
- issues:
- types: [opened]
-
-jobs:
- routeissue:
- uses: openedx/.github/.github/workflows/add-depr-ticket-to-depr-board.yml@master
- secrets:
- GITHUB_APP_ID: ${{ secrets.GRAPHQL_AUTH_APP_ID }}
- GITHUB_APP_PRIVATE_KEY: ${{ secrets.GRAPHQL_AUTH_APP_PEM }}
- SLACK_BOT_TOKEN: ${{ secrets.SLACK_ISSUE_BOT_TOKEN }}
diff --git a/.github/workflows/add-remove-label-on-comment.yml b/.github/workflows/add-remove-label-on-comment.yml
deleted file mode 100644
index 0f369db7d293..000000000000
--- a/.github/workflows/add-remove-label-on-comment.yml
+++ /dev/null
@@ -1,20 +0,0 @@
-# This workflow runs when a comment is made on the ticket
-# If the comment starts with "label: " it tries to apply
-# the label indicated in rest of comment.
-# If the comment starts with "remove label: ", it tries
-# to remove the indicated label.
-# Note: Labels are allowed to have spaces and this script does
-# not parse spaces (as often a space is legitimate), so the command
-# "label: really long lots of words label" will apply the
-# label "really long lots of words label"
-
-name: Allows for the adding and removing of labels via comment
-
-on:
- issue_comment:
- types: [created]
-
-jobs:
- add_remove_labels:
- uses: openedx/.github/.github/workflows/add-remove-label-on-comment.yml@master
-
diff --git a/.github/workflows/js-tests.yml b/.github/workflows/js-tests.yml
index 94a1368e96a5..972435a820e3 100644
--- a/.github/workflows/js-tests.yml
+++ b/.github/workflows/js-tests.yml
@@ -4,7 +4,7 @@ on:
pull_request:
push:
branches:
- - master
+ - release-ulmo
jobs:
run_tests:
diff --git a/.github/workflows/lint-imports.yml b/.github/workflows/lint-imports.yml
index baf914298be2..17cc4ea0a935 100644
--- a/.github/workflows/lint-imports.yml
+++ b/.github/workflows/lint-imports.yml
@@ -4,7 +4,7 @@ on:
pull_request:
push:
branches:
- - master
+ - release-ulmo
jobs:
lint-imports:
diff --git a/.github/workflows/lockfileversion-check.yml b/.github/workflows/lockfileversion-check.yml
index 736f1f98de13..ed0b7f97de6d 100644
--- a/.github/workflows/lockfileversion-check.yml
+++ b/.github/workflows/lockfileversion-check.yml
@@ -5,7 +5,7 @@ name: Lockfile Version check
on:
push:
branches:
- - master
+ - release-ulmo
pull_request:
jobs:
diff --git a/.github/workflows/migrations-check.yml b/.github/workflows/migrations-check.yml
index cd4d09589c12..686d8d9086b0 100644
--- a/.github/workflows/migrations-check.yml
+++ b/.github/workflows/migrations-check.yml
@@ -5,7 +5,7 @@ on:
pull_request:
push:
branches:
- - master
+ - release-ulmo
jobs:
check_migrations:
diff --git a/.github/workflows/quality-checks.yml b/.github/workflows/quality-checks.yml
index 3f4cbeeb4df9..ee67e3903569 100644
--- a/.github/workflows/quality-checks.yml
+++ b/.github/workflows/quality-checks.yml
@@ -4,8 +4,7 @@ on:
pull_request:
push:
branches:
- - master
- - open-release/lilac.master
+ - release-ulmo
jobs:
run_tests:
diff --git a/.github/workflows/semgrep.yml b/.github/workflows/semgrep.yml
index 520cd23a678b..d9d32ab9d36d 100644
--- a/.github/workflows/semgrep.yml
+++ b/.github/workflows/semgrep.yml
@@ -9,7 +9,7 @@ on:
pull_request:
push:
branches:
- - master
+ - release-ulmo
jobs:
run_semgrep:
diff --git a/.github/workflows/shellcheck.yml b/.github/workflows/shellcheck.yml
index 2e5b04bcc2ff..a8df63a20d57 100644
--- a/.github/workflows/shellcheck.yml
+++ b/.github/workflows/shellcheck.yml
@@ -9,7 +9,7 @@ on:
pull_request:
push:
branches:
- - master
+ - release-ulmo
permissions:
contents: read
diff --git a/.github/workflows/static-assets-check.yml b/.github/workflows/static-assets-check.yml
index 43cb597c16d7..3a0afa76deb7 100644
--- a/.github/workflows/static-assets-check.yml
+++ b/.github/workflows/static-assets-check.yml
@@ -4,7 +4,7 @@ on:
pull_request:
push:
branches:
- - master
+ - release-ulmo
jobs:
static_assets_check:
diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml
index d8b8c26cd049..036d411fa1a0 100644
--- a/.github/workflows/unit-tests.yml
+++ b/.github/workflows/unit-tests.yml
@@ -4,7 +4,7 @@ on:
pull_request:
push:
branches:
- - master
+ - release-ulmo
concurrency:
# We only need to be running tests for the latest commit on each PR
@@ -74,6 +74,12 @@ jobs:
run: |
sudo apt-get update && sudo apt-get install libmysqlclient-dev libxmlsec1-dev lynx
+ - name: Upgrade Docker
+ run: |
+ sudo apt-get update
+ sudo apt-get install --only-upgrade docker-ce docker-ce-cli containerd.io
+ docker --version
+
# We pull this image a lot, and Dockerhub will rate limit us if we pull too often.
# This is an attempt to cache the image for better performance and to work around that.
# It will cache all pulled images, so if we add new images to this we'll need to update the key.
@@ -83,9 +89,10 @@ jobs:
key: docker-${{ runner.os }}-mongo-${{ matrix.mongo-version }}
- name: Start MongoDB
- uses: supercharge/mongodb-github-action@1.12.0
- with:
- mongodb-version: ${{ matrix.mongo-version }}
+ run: |
+ docker run -d -p 27017:27017 --name mongodb mongo:${{ matrix.mongo-version }}
+ sleep 10
+ docker ps
- name: Setup Python
uses: actions/setup-python@v5
@@ -124,26 +131,26 @@ jobs:
shell: bash
run: |
cd test_root/log
- mv pytest_warnings.json pytest_warnings_${{ matrix.shard_name }}.json
+ mv pytest_warnings.json pytest_warnings_${{ matrix.shard_name }}_${{ matrix.python-version }}_${{ matrix.django-version }}_${{ matrix.mongo-version }}_${{ matrix.os-version }}.json
- name: save pytest warnings json file
if: success()
uses: actions/upload-artifact@v4
with:
- name: pytest-warnings-json-${{ matrix.shard_name }}
+ name: pytest-warnings-json-${{ matrix.shard_name }}-${{ matrix.python-version }}-${{ matrix.django-version }}-${{ matrix.mongo-version }}-${{ matrix.os-version }}
path: |
test_root/log/pytest_warnings*.json
overwrite: true
- name: Renaming coverage data file
run: |
- mv reports/.coverage reports/${{ matrix.shard_name }}.coverage
+ mv reports/.coverage reports/${{ matrix.shard_name }}_${{ matrix.python-version }}_${{ matrix.django-version }}_${{ matrix.mongo-version }}_${{ matrix.os-version }}.coverage
- name: Upload coverage
uses: actions/upload-artifact@v4
with:
- name: coverage-${{ matrix.shard_name }}
- path: reports/${{ matrix.shard_name }}.coverage
+ name: coverage-${{ matrix.shard_name }}-${{ matrix.python-version }}-${{ matrix.django-version }}-${{ matrix.mongo-version }}-${{ matrix.os-version }}
+ path: reports/${{ matrix.shard_name }}_${{ matrix.python-version }}_${{ matrix.django-version }}_${{ matrix.mongo-version }}_${{ matrix.os-version }}.coverage
overwrite: true
collect-and-verify:
diff --git a/.github/workflows/units-test-scripts-structures-pruning.yml b/.github/workflows/units-test-scripts-structures-pruning.yml
index 14a01b592308..ef408cfe66ec 100644
--- a/.github/workflows/units-test-scripts-structures-pruning.yml
+++ b/.github/workflows/units-test-scripts-structures-pruning.yml
@@ -4,7 +4,7 @@ on:
pull_request:
push:
branches:
- - master
+ - release-ulmo
jobs:
test:
diff --git a/.github/workflows/units-test-scripts-user-retirement.yml b/.github/workflows/units-test-scripts-user-retirement.yml
index 889c43a64a48..b43bbf46b0d4 100644
--- a/.github/workflows/units-test-scripts-user-retirement.yml
+++ b/.github/workflows/units-test-scripts-user-retirement.yml
@@ -4,7 +4,7 @@ on:
pull_request:
push:
branches:
- - master
+ - release-ulmo
jobs:
test:
diff --git a/.github/workflows/update-geolite-database.yml b/.github/workflows/update-geolite-database.yml
index 484fa167a371..d9ad767ce57e 100644
--- a/.github/workflows/update-geolite-database.yml
+++ b/.github/workflows/update-geolite-database.yml
@@ -8,7 +8,7 @@ on:
branch:
description: "Target branch against which to create PR"
required: false
- default: "master"
+ default: "release-ulmo"
env:
MAXMIND_URL: "https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-Country&license_key=${{ secrets.MAXMIND_LICENSE_KEY }}&suffix=tar.gz"
@@ -69,7 +69,7 @@ jobs:
- name: Create a branch, commit the code and make a PR
id: create-pr
run: |
- BRANCH="${{ github.actor }}/geoip2-bot-update-country-database-$(echo "${{ github.sha }}" | cut -c 1-7)"
+ BRANCH="${{ github.actor }}/geoip2-bot-update-country-database-${{ github.run_id }}"
git checkout -b $BRANCH
git add .
git status
@@ -79,13 +79,13 @@ jobs:
--title "Update GeoLite Database" \
--body "PR generated by workflow `${{ github.workflow }}` on behalf of @${{ github.actor }}." \
--head $BRANCH \
- --base 'master' \
- --reviewer 'feanil' \
+ --base 'release-ulmo' \
+ --reviewer 'edx/orbi-bom' \
| grep -o 'https://github.com/.*/pull/[0-9]*')
echo "PR Created: ${PR_URL}"
echo "pull-request-url=$PR_URL" >> $GITHUB_OUTPUT
env:
- GH_TOKEN: ${{ github.token }}
+ GH_TOKEN: ${{ secrets.GH_PAT_WITH_ORG }}
- name: Job summary
run: |
diff --git a/.github/workflows/upgrade-one-python-dependency.yml b/.github/workflows/upgrade-one-python-dependency.yml
index 3f9678593c25..1d8170961865 100644
--- a/.github/workflows/upgrade-one-python-dependency.yml
+++ b/.github/workflows/upgrade-one-python-dependency.yml
@@ -6,7 +6,7 @@ on:
branch:
description: "Target branch to create requirements PR against"
required: true
- default: "master"
+ default: "release-ulmo"
type: string
package:
description: "Name of package to upgrade"
diff --git a/.github/workflows/upgrade-python-requirements.yml b/.github/workflows/upgrade-python-requirements.yml
index cbb70b06b79d..90c6be00744c 100644
--- a/.github/workflows/upgrade-python-requirements.yml
+++ b/.github/workflows/upgrade-python-requirements.yml
@@ -8,7 +8,7 @@ on:
branch:
description: "Target branch to create requirements PR against"
required: true
- default: "master"
+ default: "release-ulmo"
jobs:
call-upgrade-python-requirements-workflow:
# Don't run the weekly upgrade job on forks -- it will send a weekly failure email.
diff --git a/.github/workflows/verify-dunder-init.yml b/.github/workflows/verify-dunder-init.yml
index c3248def2f33..462f0e06af0b 100644
--- a/.github/workflows/verify-dunder-init.yml
+++ b/.github/workflows/verify-dunder-init.yml
@@ -3,7 +3,7 @@ name: Verify Dunder __init__.py Files
on:
pull_request:
branches:
- - master
+ - release-ulmo
jobs:
verify_dunder_init:
diff --git a/.gitignore b/.gitignore
index a5d5252de705..5587df58faac 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,6 +11,10 @@ requirements/edx/private.in
requirements/edx/private.txt
lms/envs/private.py
cms/envs/private.py
+.venv/
+CLAUDE.md
+.claude/
+AGENTS.md
# end-noclean
### Python artifacts
diff --git a/Makefile b/Makefile
index 6c525a57b67e..66c3608af038 100644
--- a/Makefile
+++ b/Makefile
@@ -58,7 +58,8 @@ pull_translations: clean_translations ## pull translations via atlas
make pull_plugin_translations
atlas pull $(ATLAS_OPTIONS) \
translations/edx-platform/conf/locale:conf/locale \
- translations/studio-frontend/src/i18n/messages:conf/plugins-locale/studio-frontend
+ translations/studio-frontend/src/i18n/messages:conf/plugins-locale/studio-frontend \
+ $(ATLAS_EXTRA_SOURCES)
python manage.py lms compilemessages
python manage.py lms compilejsi18n
python manage.py cms compilejsi18n
diff --git a/README.rst b/README.rst
index dcd6e32c4998..a67e281f041b 100644
--- a/README.rst
+++ b/README.rst
@@ -74,7 +74,7 @@ OS:
* Ubuntu 24.04
-Interperters/Tools:
+Interpreters/Tools:
* Python 3.11
diff --git a/cms/djangoapps/contentstore/exams.py b/cms/djangoapps/contentstore/exams.py
index 6b25147c6abc..8a4ddc09425e 100644
--- a/cms/djangoapps/contentstore/exams.py
+++ b/cms/djangoapps/contentstore/exams.py
@@ -74,13 +74,13 @@ def register_exams(course_key):
# Exams in courses not using an LTI based proctoring provider should use the original definition of due_date
# from contentstore/proctoring.py. These exams are powered by the edx-proctoring plugin and not the edx-exams
# microservice.
+ is_instructor_paced = not course.self_paced
if course.proctoring_provider == 'lti_external':
- due_date = (
- timed_exam.due.isoformat() if timed_exam.due
- else (course.end.isoformat() if course.end else None)
- )
+ due_date_source = timed_exam.due if is_instructor_paced else course.end
else:
- due_date = timed_exam.due if not course.self_paced else None
+ due_date_source = timed_exam.due if is_instructor_paced else None
+
+ due_date = due_date_source.isoformat() if due_date_source else None
exams_list.append({
'course_id': str(course_key),
diff --git a/cms/djangoapps/contentstore/helpers.py b/cms/djangoapps/contentstore/helpers.py
index 2bdbabc7df8c..91236a4dade9 100644
--- a/cms/djangoapps/contentstore/helpers.py
+++ b/cms/djangoapps/contentstore/helpers.py
@@ -745,10 +745,10 @@ def _import_file_into_course(
if thumbnail_content is not None:
content.thumbnail_location = thumbnail_location
contentstore().save(content)
- return True, {clipboard_file_path: f"static/{import_path}"}
+ return True, {clipboard_file_path: filename if not import_path else f"static/{import_path}"}
elif current_file.content_digest == file_data_obj.md5_hash:
- # The file already exists and matches exactly, so no action is needed except substitutions
- return None, {clipboard_file_path: f"static/{import_path}"}
+ # The file already exists and matches exactly, so no action is needed
+ return None, {}
else:
# There is a conflict with some other file that has the same name.
return False, {}
diff --git a/cms/djangoapps/contentstore/rest_api/v1/serializers/course_waffle_flags.py b/cms/djangoapps/contentstore/rest_api/v1/serializers/course_waffle_flags.py
index 3efb7b6226d4..fde90163f803 100644
--- a/cms/djangoapps/contentstore/rest_api/v1/serializers/course_waffle_flags.py
+++ b/cms/djangoapps/contentstore/rest_api/v1/serializers/course_waffle_flags.py
@@ -11,6 +11,7 @@ class CourseWaffleFlagsSerializer(serializers.Serializer):
"""
Serializer for course waffle flags
"""
+
use_new_home_page = serializers.SerializerMethodField()
use_new_custom_pages = serializers.SerializerMethodField()
use_new_schedule_details_page = serializers.SerializerMethodField()
@@ -31,6 +32,7 @@ class CourseWaffleFlagsSerializer(serializers.Serializer):
use_react_markdown_editor = serializers.SerializerMethodField()
use_video_gallery_flow = serializers.SerializerMethodField()
enable_course_optimizer_check_prev_run_links = serializers.SerializerMethodField()
+ enable_unit_expanded_view = serializers.SerializerMethodField()
def get_course_key(self):
"""
@@ -175,3 +177,10 @@ def get_enable_course_optimizer_check_prev_run_links(self, obj):
"""
course_key = self.get_course_key()
return toggles.enable_course_optimizer_check_prev_run_links(course_key)
+
+ def get_enable_unit_expanded_view(self, obj):
+ """
+ Method to get the enable_unit_expanded_view waffle flag
+ """
+ course_key = self.get_course_key()
+ return toggles.enable_unit_expanded_view(course_key)
diff --git a/cms/djangoapps/contentstore/rest_api/v1/urls.py b/cms/djangoapps/contentstore/rest_api/v1/urls.py
index 685a81d778ce..8a94f0b0e040 100644
--- a/cms/djangoapps/contentstore/rest_api/v1/urls.py
+++ b/cms/djangoapps/contentstore/rest_api/v1/urls.py
@@ -25,6 +25,7 @@
HomePageView,
ProctoredExamSettingsView,
ProctoringErrorsView,
+ UnitComponentsView,
VideoDownloadView,
VideoUsageView,
vertical_container_children_redirect_view,
@@ -144,6 +145,11 @@
CourseWaffleFlagsView.as_view(),
name="course_waffle_flags"
),
+ re_path(
+ fr'^unit_handler/{settings.USAGE_KEY_PATTERN}$',
+ UnitComponentsView.as_view(),
+ name="unit_components"
+ ),
# Authoring API
# Do not use under v1 yet (Nov. 23). The Authoring API is still experimental and the v0 versions should be used
diff --git a/cms/djangoapps/contentstore/rest_api/v1/views/__init__.py b/cms/djangoapps/contentstore/rest_api/v1/views/__init__.py
index d4fcfd5f2e3f..25c0157904e9 100644
--- a/cms/djangoapps/contentstore/rest_api/v1/views/__init__.py
+++ b/cms/djangoapps/contentstore/rest_api/v1/views/__init__.py
@@ -14,9 +14,6 @@
from .proctoring import ProctoredExamSettingsView, ProctoringErrorsView
from .settings import CourseSettingsView
from .textbooks import CourseTextbooksView
+from .unit_handler import UnitComponentsView
from .vertical_block import ContainerHandlerView, vertical_container_children_redirect_view
-from .videos import (
- CourseVideosView,
- VideoDownloadView,
- VideoUsageView,
-)
+from .videos import CourseVideosView, VideoDownloadView, VideoUsageView
diff --git a/cms/djangoapps/contentstore/rest_api/v1/views/tests/test_course_waffle_flags.py b/cms/djangoapps/contentstore/rest_api/v1/views/tests/test_course_waffle_flags.py
index f45cc48810d6..a788ce4af3b5 100644
--- a/cms/djangoapps/contentstore/rest_api/v1/views/tests/test_course_waffle_flags.py
+++ b/cms/djangoapps/contentstore/rest_api/v1/views/tests/test_course_waffle_flags.py
@@ -38,6 +38,7 @@ class CourseWaffleFlagsViewTest(CourseTestCase):
"use_react_markdown_editor": False,
"use_video_gallery_flow": False,
"enable_course_optimizer_check_prev_run_links": False,
+ "enable_unit_expanded_view": False,
}
def setUp(self):
diff --git a/cms/djangoapps/contentstore/rest_api/v1/views/unit_handler.py b/cms/djangoapps/contentstore/rest_api/v1/views/unit_handler.py
new file mode 100644
index 000000000000..0740152e4fab
--- /dev/null
+++ b/cms/djangoapps/contentstore/rest_api/v1/views/unit_handler.py
@@ -0,0 +1,129 @@
+"""API Views for unit components handler"""
+
+import logging
+
+import edx_api_doc_tools as apidocs
+from django.http import HttpResponseBadRequest
+from opaque_keys.edx.keys import UsageKey
+from rest_framework.request import Request
+from rest_framework.response import Response
+from rest_framework.views import APIView
+
+from cms.djangoapps.contentstore.rest_api.v1.mixins import ContainerHandlerMixin
+from openedx.core.lib.api.view_utils import view_auth_classes
+from xmodule.modulestore.django import modulestore
+from xmodule.modulestore.exceptions import ItemNotFoundError
+
+log = logging.getLogger(__name__)
+
+
+@view_auth_classes(is_authenticated=True)
+class UnitComponentsView(APIView, ContainerHandlerMixin):
+ """
+ View to get all components in a unit by usage key.
+ """
+
+ @apidocs.schema(
+ parameters=[
+ apidocs.string_parameter(
+ "usage_key_string",
+ apidocs.ParameterLocation.PATH,
+ description="Unit usage key",
+ ),
+ ],
+ responses={
+ 200: "List of components in the unit",
+ 400: "Invalid usage key or unit not found.",
+ 401: "The requester is not authenticated.",
+ 404: "The requested unit does not exist.",
+ },
+ )
+ def get(self, request: Request, usage_key_string: str):
+ """
+ Get all components in a unit.
+
+ **Example Request**
+
+ GET /api/contentstore/v1/unit_handler/{usage_key_string}
+
+ **Response Values**
+
+ If the request is successful, an HTTP 200 "OK" response is returned.
+
+ The HTTP 200 response contains a dict with a list of all components
+ in the unit, including their display names, block types, and block IDs.
+
+ **Example Response**
+
+ ```json
+ {
+ "unit_id": "block-v1:edX+DemoX+Demo_Course+type@vertical+block@vertical_id",
+ "display_name": "My Unit",
+ "components": [
+ {
+ "block_id": "block-v1:edX+DemoX+Demo_Course+type@video+block@video_id",
+ "block_type": "video",
+ "display_name": "Introduction Video"
+ },
+ {
+ "block_id": "block-v1:edX+DemoX+Demo_Course+type@html+block@html_id",
+ "block_type": "html",
+ "display_name": "Text Content"
+ },
+ {
+ "block_id": "block-v1:edX+DemoX+Demo_Course+type@problem+block@problem_id",
+ "block_type": "problem",
+ "display_name": "Practice Problem"
+ }
+ ]
+ }
+ ```
+ """
+ try:
+ usage_key = UsageKey.from_string(usage_key_string)
+ except Exception as e: # pylint: disable=broad-exception-caught
+ log.error(f"Invalid usage key: {usage_key_string}, error: {str(e)}")
+ return HttpResponseBadRequest("Invalid usage key format")
+
+ try:
+ # Get the unit xblock
+ unit_xblock = modulestore().get_item(usage_key)
+
+ # Verify it's a vertical (unit)
+ if unit_xblock.category != "vertical":
+ return HttpResponseBadRequest(
+ "The provided usage key is not a unit (vertical)"
+ )
+
+ components = []
+
+ # Get all children (components) of the unit
+ if unit_xblock.has_children:
+ for child_usage_key in unit_xblock.children:
+ try:
+ child_xblock = modulestore().get_item(child_usage_key)
+ components.append(
+ {
+ "block_id": str(child_xblock.location),
+ "block_type": child_xblock.category,
+ "display_name": child_xblock.display_name_with_default,
+ }
+ )
+ except ItemNotFoundError:
+ log.warning(f"Child block not found: {child_usage_key}")
+ continue
+
+ response_data = {
+ "unit_id": str(usage_key),
+ "display_name": unit_xblock.display_name_with_default,
+ "components": components,
+ }
+
+ return Response(response_data)
+
+ except ItemNotFoundError:
+ log.error(f"Unit not found: {usage_key_string}")
+ return HttpResponseBadRequest("Unit not found")
+ except Exception as e: # pylint: disable=broad-exception-caught
+ log.error(f"Error retrieving unit components: {str(e)}")
+ return HttpResponseBadRequest(f"Error retrieving unit components: {str(e)}")
diff --git a/cms/djangoapps/contentstore/tests/test_course_listing.py b/cms/djangoapps/contentstore/tests/test_course_listing.py
index e46b493b7b39..a2b6f07d15ef 100644
--- a/cms/djangoapps/contentstore/tests/test_course_listing.py
+++ b/cms/djangoapps/contentstore/tests/test_course_listing.py
@@ -24,8 +24,10 @@
get_courses_accessible_to_user
)
from common.djangoapps.course_action_state.models import CourseRerunState
+from common.djangoapps.student.models.user import CourseAccessRole
from common.djangoapps.student.roles import (
CourseInstructorRole,
+ CourseLimitedStaffRole,
CourseStaffRole,
GlobalStaff,
OrgInstructorRole,
@@ -188,6 +190,48 @@ def test_staff_course_listing(self):
with self.assertNumQueries(2):
list(_accessible_courses_summary_iter(self.request))
+ def test_course_limited_staff_course_listing(self):
+ # Setup a new course
+ course_location = self.store.make_course_key('Org', 'CreatedCourse', 'Run')
+ CourseFactory.create(
+ org=course_location.org,
+ number=course_location.course,
+ run=course_location.run
+ )
+ course = CourseOverviewFactory.create(id=course_location, org=course_location.org)
+
+ # Add the user as a course_limited_staff on the course
+ CourseLimitedStaffRole(course.id).add_users(self.user)
+ self.assertTrue(CourseLimitedStaffRole(course.id).has_user(self.user))
+
+ # Fetch accessible courses list & verify their count
+ courses_list_by_staff, __ = get_courses_accessible_to_user(self.request)
+
+ # Limited Course Staff should not be able to list courses in Studio
+ assert len(list(courses_list_by_staff)) == 0
+
+ def test_org_limited_staff_course_listing(self):
+
+ # Setup a new course
+ course_location = self.store.make_course_key('Org', 'CreatedCourse', 'Run')
+ CourseFactory.create(
+ org=course_location.org,
+ number=course_location.course,
+ run=course_location.run
+ )
+ course = CourseOverviewFactory.create(id=course_location, org=course_location.org)
+
+ # Add a user as course_limited_staff on the org
+ # This is not possible using the course roles classes but is possible via Django admin so we
+ # insert a row into the model directly to test that scenario.
+ CourseAccessRole.objects.create(user=self.user, org=course_location.org, role=CourseLimitedStaffRole.ROLE)
+
+ # Fetch accessible courses list & verify their count
+ courses_list_by_staff, __ = get_courses_accessible_to_user(self.request)
+
+ # Limited Course Staff should not be able to list courses in Studio
+ assert len(list(courses_list_by_staff)) == 0
+
def test_get_course_list_with_invalid_course_location(self):
"""
Test getting courses with invalid course location (course deleted from modulestore).
diff --git a/cms/djangoapps/contentstore/tests/test_exams.py b/cms/djangoapps/contentstore/tests/test_exams.py
index 798b5e51fd80..823038957714 100644
--- a/cms/djangoapps/contentstore/tests/test_exams.py
+++ b/cms/djangoapps/contentstore/tests/test_exams.py
@@ -65,16 +65,21 @@ def _get_exam_due_date(self, course, sequential):
Return the expected exam due date for the exam, based on the selected course proctoring provider and the
exam due date or the course end date.
+ This is a copy of the due date computation logic in the register_exams function.
+
Arguments:
* course: the course that the exam subsection is in; may have a course.end attribute
* sequential: the exam subsection; may have a sequential.due attribute
"""
+ is_instructor_paced = not course.self_paced
if course.proctoring_provider == 'lti_external':
- return sequential.due.isoformat() if sequential.due else (course.end.isoformat() if course.end else None)
- elif course.self_paced:
- return None
+ due_date_source = sequential.due if is_instructor_paced else course.end
else:
- return sequential.due
+ due_date_source = sequential.due if is_instructor_paced else None
+
+ due_date = due_date_source.isoformat() if due_date_source else None
+
+ return due_date
@ddt.data(*(tuple(base) + (extra,) for base, extra in itertools.product(
[
@@ -185,14 +190,13 @@ def test_feature_flag_off(self, mock_patch_course_exams):
def test_no_due_dates(self, is_self_paced, course_end_date, proctoring_provider, mock_patch_course_exams):
"""
Test that the the correct due date is registered for the exam when the subsection does not have a due date,
- depending on the proctoring provider.
+ depending on the proctoring provider and course pacing type.
* lti_external
- * The course end date is registered as the due date when the subsection does not have a due date for both
- self-paced and instructor-paced exams.
+ * If the course is instructor-paced, the exam due date is the subsection due date if it exists, else None.
+ * If the course is self-paced, the exam due date is the course end date if it exists, else None.
* not lti_external
- * None is registered as the due date when the subsection does not have a due date for both
- self-paced and instructor-paced exams.
+ * For instructor-paced courses, the exam due date is the subsection due date if it exists, else None; for self-paced courses it is always None.
"""
self.course.self_paced = is_self_paced
self.course.end = course_end_date
@@ -222,25 +226,17 @@ def test_no_due_dates(self, is_self_paced, course_end_date, proctoring_provider,
@ddt.data(*itertools.product((True, False), ('lti_external', 'null')))
@ddt.unpack
@freeze_time('2024-01-01')
- def test_subsection_due_date_prioritized(self, is_self_paced, proctoring_provider, mock_patch_course_exams):
+ def test_subsection_due_date_prioritized_instructor_paced(
+ self,
+ is_self_paced,
+ proctoring_provider,
+ mock_patch_course_exams
+ ):
"""
- Test that the subsection due date is registered as the due date when both the subsection has a due date and the
- course has an end date for both self-paced and instructor-paced exams.
-
- Test that the the correct due date is registered for the exam when the subsection has a due date, depending on
- the proctoring provider.
-
- * lti_external
- * The subsection due date is registered as the due date when both the subsection has a due date and the
- course has an end date for both self-paced and instructor-paced exams
- * not lti_external
- * None is registered as the due date when both the subsection has a due date and the course has an end date
- for self-paced exams.
- * The subsection due date is registered as the due date when both the subsection has a due date and the
- course has an end date for instructor-paced exams.
+ Test that the exam due date is computed correctly.
"""
self.course.self_paced = is_self_paced
- self.course.end = datetime(2035, 1, 1, 0, 0)
+ self.course.end = datetime(2035, 1, 1, 0, 0, tzinfo=timezone.utc)
self.course.proctoring_provider = proctoring_provider
self.course = self.update_course(self.course, 1)
@@ -260,7 +256,7 @@ def test_subsection_due_date_prioritized(self, is_self_paced, proctoring_provide
)
listen_for_course_publish(self, self.course.id)
- called_exams, called_course = mock_patch_course_exams.call_args[0]
+ called_exams, _ = mock_patch_course_exams.call_args[0]
expected_due_date = self._get_exam_due_date(self.course, sequence)
diff --git a/cms/djangoapps/contentstore/toggles.py b/cms/djangoapps/contentstore/toggles.py
index c287f8c4dbec..c3dba4a6f4a4 100644
--- a/cms/djangoapps/contentstore/toggles.py
+++ b/cms/djangoapps/contentstore/toggles.py
@@ -682,3 +682,23 @@ def enable_course_optimizer_check_prev_run_links(course_key):
Returns a boolean if previous run course optimizer feature is enabled for the given course.
"""
return ENABLE_COURSE_OPTIMIZER_CHECK_PREV_RUN_LINKS.is_enabled(course_key)
+
+
+# .. toggle_name: contentstore.enable_unit_expanded_view
+# .. toggle_implementation: CourseWaffleFlag
+# .. toggle_default: False
+# .. toggle_description: When enabled, the Unit Expanded View feature in the Course Outline is activated.
+# .. toggle_use_cases: temporary
+# .. toggle_creation_date: 2026-01-01
+# .. toggle_target_removal_date: 2026-06-01
+# .. toggle_tickets: TNL2-473
+ENABLE_UNIT_EXPANDED_VIEW = CourseWaffleFlag(
+ f"{CONTENTSTORE_NAMESPACE}.enable_unit_expanded_view", __name__
+)
+
+
+def enable_unit_expanded_view(course_key):
+ """
+ Returns a boolean if the Unit Expanded View feature is enabled for the given course.
+ """
+ return ENABLE_UNIT_EXPANDED_VIEW.is_enabled(course_key)
diff --git a/cms/djangoapps/contentstore/video_storage_handlers.py b/cms/djangoapps/contentstore/video_storage_handlers.py
index 87086c9951ac..c39970d56d20 100644
--- a/cms/djangoapps/contentstore/video_storage_handlers.py
+++ b/cms/djangoapps/contentstore/video_storage_handlers.py
@@ -13,12 +13,11 @@
import shutil
import pathlib
import zipfile
-
from contextlib import closing
from datetime import datetime, timedelta
from uuid import uuid4
-from boto.s3.connection import S3Connection
-from boto import s3
+
+import boto3
from django.conf import settings
from django.contrib.staticfiles.storage import staticfiles_storage
from django.http import FileResponse, HttpResponseNotFound, StreamingHttpResponse
@@ -55,10 +54,7 @@
from common.djangoapps.util.json_request import JsonResponse
from openedx.core.djangoapps.video_config.models import VideoTranscriptEnabledFlag
from openedx.core.djangoapps.video_config.toggles import PUBLIC_VIDEO_SHARE
-from openedx.core.djangoapps.video_pipeline.config.waffle import (
- DEPRECATE_YOUTUBE,
- ENABLE_DEVSTACK_VIDEO_UPLOADS,
-)
+from openedx.core.djangoapps.video_pipeline.config.waffle import DEPRECATE_YOUTUBE
from openedx.core.djangoapps.waffle_utils import CourseWaffleFlag
from xmodule.modulestore.django import modulestore # lint-amnesty, pylint: disable=wrong-import-order
@@ -812,7 +808,8 @@ def videos_post(course, request):
if error:
return {'error': error}, 400
- bucket = storage_service_bucket()
+ s3_client = boto3.client('s3')
+
req_files = data['files']
resp_files = []
@@ -826,7 +823,6 @@ def videos_post(course, request):
return {'error': error_msg}, 400
edx_video_id = str(uuid4())
- key = storage_service_key(bucket, file_name=edx_video_id)
metadata_list = [
('client_video_id', file_name),
@@ -846,12 +842,15 @@ def videos_post(course, request):
if transcript_preferences is not None:
metadata_list.append(('transcript_preferences', json.dumps(transcript_preferences)))
- for metadata_name, value in metadata_list:
- key.set_metadata(metadata_name, value)
- upload_url = key.generate_url(
- KEY_EXPIRATION_IN_SECONDS,
- 'PUT',
- headers={'Content-Type': req_file['content_type']}
+ upload_url = s3_client.generate_presigned_url(
+ ClientMethod='put_object',
+ Params={
+ 'Bucket': storage_service_bucket_name(),
+ 'Key': storage_service_key_name(edx_video_id),
+ 'ContentType': req_file['content_type'],
+ 'Metadata': dict(metadata_list),
+ },
+ ExpiresIn=KEY_EXPIRATION_IN_SECONDS,
)
# persist edx_video_id in VAL
@@ -869,41 +868,21 @@ def videos_post(course, request):
return {'files': resp_files}, 200
-def storage_service_bucket():
+def storage_service_bucket_name():
"""
- Returns an S3 bucket for video upload.
+ Returns name of S3 bucket to use for video upload.
"""
- if ENABLE_DEVSTACK_VIDEO_UPLOADS.is_enabled():
- params = {
- 'aws_access_key_id': settings.AWS_ACCESS_KEY_ID,
- 'aws_secret_access_key': settings.AWS_SECRET_ACCESS_KEY,
- 'security_token': settings.AWS_SECURITY_TOKEN
-
- }
- else:
- params = {
- 'aws_access_key_id': settings.AWS_ACCESS_KEY_ID,
- 'aws_secret_access_key': settings.AWS_SECRET_ACCESS_KEY
- }
-
- conn = S3Connection(**params)
-
- # We don't need to validate our bucket, it requires a very permissive IAM permission
- # set since behind the scenes it fires a HEAD request that is equivalent to get_all_keys()
- # meaning it would need ListObjects on the whole bucket, not just the path used in each
- # environment (since we share a single bucket for multiple deployments in some configurations)
- return conn.get_bucket(settings.VIDEO_UPLOAD_PIPELINE['VEM_S3_BUCKET'], validate=False)
+ return settings.VIDEO_UPLOAD_PIPELINE['VEM_S3_BUCKET']
-def storage_service_key(bucket, file_name):
+def storage_service_key_name(file_name):
"""
- Returns an S3 key to the given file in the given bucket.
+ Returns the S3 object key to be used for a given video filename.
"""
- key_name = "{}/{}".format(
+ return "{}/{}".format(
settings.VIDEO_UPLOAD_PIPELINE.get("ROOT_PATH", ""),
file_name
)
- return s3.key.Key(bucket, key_name)
def send_video_status_update(updates):
diff --git a/cms/djangoapps/contentstore/views/block.py b/cms/djangoapps/contentstore/views/block.py
index b57042085df2..238627bc6618 100644
--- a/cms/djangoapps/contentstore/views/block.py
+++ b/cms/djangoapps/contentstore/views/block.py
@@ -1,8 +1,10 @@
"""Views for blocks."""
import logging
+import re
from collections import OrderedDict
from functools import partial
+from urllib.parse import urlparse
from django.contrib.auth.decorators import login_required
from django.core.exceptions import PermissionDenied
@@ -305,6 +307,43 @@ def xblock_view_handler(request, usage_key_string, view_name):
return HttpResponse(status=406)
+def _get_safe_return_to(request):
+ """
+ Read and validate the ``returnTo`` query parameter for the XBlock edit view.
+
+ Returns the parameter value if it is a safe same-origin URL (i.e. an
+ absolute-path reference that starts with ``/`` but not ``//``), or ``None``
+ if the parameter is absent or fails validation. This prevents open-redirect
+ attacks via protocol-relative URLs such as ``//evil.com/path``.
+ """
+ return_to = request.GET.get('returnTo', '').strip()
+ if not return_to:
+ return None
+
+ if re.search(r'[\x00-\x1f\x7f]', return_to):
+ return None
+
+ if len(return_to) > 2048:
+ return None
+
+ parsed = urlparse(return_to)
+ if parsed.scheme or parsed.netloc:
+ request_origin = '{scheme}://{host}'.format(
+ scheme=request.scheme,
+ host=request.get_host(),
+ )
+ url_origin = '{scheme}://{host}'.format(
+ scheme=parsed.scheme,
+ host=parsed.netloc,
+ )
+ if request_origin != url_origin:
+ return None
+ elif not return_to.startswith('/') or return_to.startswith('//'):
+ return None
+
+ return return_to
+
+
@xframe_options_exempt
@require_http_methods(["GET"])
@login_required
@@ -313,6 +352,10 @@ def xblock_edit_view(request, usage_key_string):
Return rendered xblock edit view.
Allows editing of an XBlock specified by the usage key.
+
+ Supports an optional ``returnTo`` query parameter. When present and
+ pointing to a same-origin URL, the editor will redirect the browser to
+ that URL after the user saves or cancels instead of leaving the page blank.
"""
usage_key = usage_key_with_run(usage_key_string)
if not has_studio_read_access(request.user, usage_key.course_key):
@@ -333,6 +376,7 @@ def xblock_edit_view(request, usage_key_string):
container_handler_context.update({
"action_name": "edit",
"resources": list(hashed_resources.items()),
+ "return_to": _get_safe_return_to(request),
})
return render_to_response('container_editor.html', container_handler_context)
diff --git a/cms/djangoapps/contentstore/views/component.py b/cms/djangoapps/contentstore/views/component.py
index 34c1f465c566..e0992bff2a39 100644
--- a/cms/djangoapps/contentstore/views/component.py
+++ b/cms/djangoapps/contentstore/views/component.py
@@ -43,6 +43,12 @@
from xmodule.modulestore.django import modulestore # lint-amnesty, pylint: disable=wrong-import-order
from xmodule.modulestore.exceptions import ItemNotFoundError # lint-amnesty, pylint: disable=wrong-import-order
+try:
+ from games.toggles import is_games_xblock_enabled # pylint: disable=import-error
+except ImportError:
+ def is_games_xblock_enabled():
+ return False
+
__all__ = [
'container_handler',
'component_handler',
@@ -83,13 +89,27 @@
]
DEFAULT_ADVANCED_MODULES = [
+ 'annotatable',
+ 'done',
+ 'split_test',
+ 'freetextresponse',
'google-calendar',
'google-document',
+ 'imagemodal',
+ 'h5pxblock',
+ 'invideoquiz',
'lti_consumer',
+ 'oppia',
+ 'ubcpi',
'poll',
- 'split_test',
+ 'qualtricssurvey',
+ 'scorm',
+ 'edx_sga',
+ 'submit-and-compare',
'survey',
'word_cloud',
+ 'recommender',
+ 'library_content',
]
@@ -295,6 +315,11 @@ def create_support_legend_dict():
# by the components in the order listed in COMPONENT_TYPES.
component_types = COMPONENT_TYPES[:]
+ # Add games xblock if enabled (checked at request time)
+ if is_games_xblock_enabled():
+ component_types.append('games')
+ component_display_names['games'] = _("Games")
+
# Libraries do not support discussions, drag-and-drop, and openassessment and other libraries
component_not_supported_by_library = [
'discussion',
@@ -439,12 +464,16 @@ def create_support_legend_dict():
)
categories.add(component)
+ beta_types = BETA_COMPONENT_TYPES[:]
+ if is_games_xblock_enabled() and category == 'games':
+ beta_types.append('games')
+
component_templates.append({
"type": category,
"templates": templates_for_category,
"display_name": component_display_names[category],
"support_legend": create_support_legend_dict(),
- "beta": category in BETA_COMPONENT_TYPES,
+ "beta": category in beta_types,
})
# Libraries do not support advanced components at this time.
diff --git a/cms/djangoapps/contentstore/views/course.py b/cms/djangoapps/contentstore/views/course.py
index fa8769dc0cb9..a93a35cf1d3c 100644
--- a/cms/djangoapps/contentstore/views/course.py
+++ b/cms/djangoapps/contentstore/views/course.py
@@ -56,6 +56,7 @@
GlobalStaff,
UserBasedRole,
OrgStaffRole,
+ strict_role_checking,
)
from common.djangoapps.util.json_request import JsonResponse, JsonResponseBadRequest, expect_json
from common.djangoapps.util.string_utils import _has_non_ascii_characters
@@ -536,7 +537,9 @@ def filter_ccx(course_access):
return not isinstance(course_access.course_id, CCXLocator)
instructor_courses = UserBasedRole(request.user, CourseInstructorRole.ROLE).courses_with_role()
- staff_courses = UserBasedRole(request.user, CourseStaffRole.ROLE).courses_with_role()
+ with strict_role_checking():
+ staff_courses = UserBasedRole(request.user, CourseStaffRole.ROLE).courses_with_role()
+
all_courses = list(filter(filter_ccx, instructor_courses | staff_courses))
courses_list = []
course_keys = {}
diff --git a/cms/djangoapps/contentstore/views/tests/test_block.py b/cms/djangoapps/contentstore/views/tests/test_block.py
index 01aff3d613c1..c94038a508b5 100644
--- a/cms/djangoapps/contentstore/views/tests/test_block.py
+++ b/cms/djangoapps/contentstore/views/tests/test_block.py
@@ -76,7 +76,8 @@
from openedx.core.djangoapps.discussions.models import DiscussionsConfiguration
from openedx.core.djangoapps.content_tagging import api as tagging_api
-from ..component import component_handler, DEFAULT_ADVANCED_MODULES, get_component_templates
+from ..block import _get_safe_return_to
+from ..component import component_handler, get_component_templates
from cms.djangoapps.contentstore.xblock_storage_handlers.view_handlers import (
ALWAYS,
VisibilityState,
@@ -2974,10 +2975,23 @@ def test_basic_components(self):
self.assertGreater(len(self.get_templates_of_type("html")), 0)
self.assertGreater(len(self.get_templates_of_type("problem")), 0)
- # Check for default advanced modules
+ # Check for default advanced modules - only the ones available in test environment
advanced_templates = self.get_templates_of_type("advanced")
advanced_module_keys = [t['category'] for t in advanced_templates]
- self.assertCountEqual(advanced_module_keys, DEFAULT_ADVANCED_MODULES)
+ expected_advanced_modules = [
+ 'annotatable',
+ 'done',
+ 'google-calendar',
+ 'google-document',
+ 'lti_consumer',
+ 'poll',
+ 'split_test',
+ 'survey',
+ 'word_cloud',
+ 'recommender',
+ 'edx_sga',
+ ]
+ self.assertCountEqual(advanced_module_keys, expected_advanced_modules)
# Now fully disable video through XBlockConfiguration
XBlockConfiguration.objects.create(name="video", enabled=False)
@@ -3025,16 +3039,6 @@ def test_advanced_components(self):
"""
Test the handling of advanced component templates.
"""
- self.course.advanced_modules.append("done")
- EXPECTED_ADVANCED_MODULES_LENGTH = len(DEFAULT_ADVANCED_MODULES) + 1
- self.templates = get_component_templates(self.course)
- advanced_templates = self.get_templates_of_type("advanced")
- self.assertEqual(len(advanced_templates), EXPECTED_ADVANCED_MODULES_LENGTH)
- done_template = advanced_templates[0]
- self.assertEqual(done_template.get("category"), "done")
- self.assertEqual(done_template.get("display_name"), "Completion")
- self.assertIsNone(done_template.get("boilerplate_name", None))
-
# Verify that components are not added twice
self.course.advanced_modules.append("video")
self.course.advanced_modules.append("drag-and-drop-v2")
@@ -3045,7 +3049,6 @@ def test_advanced_components(self):
self.templates = get_component_templates(self.course)
advanced_templates = self.get_templates_of_type("advanced")
- self.assertEqual(len(advanced_templates), EXPECTED_ADVANCED_MODULES_LENGTH)
only_template = advanced_templates[0]
self.assertNotEqual(only_template.get("category"), "video")
self.assertNotEqual(only_template.get("category"), "drag-and-drop-v2")
@@ -3118,8 +3121,13 @@ def test_create_support_level_flag_off(self):
"""
XBlockStudioConfigurationFlag.objects.create(enabled=False)
self.course.advanced_modules.extend(["annotatable", "done"])
- expected_xblocks = ["Annotation", "Completion"] + self.default_advanced_modules_titles
- self._verify_advanced_xblocks(expected_xblocks, [True] * len(expected_xblocks))
+ # Get actual templates to determine count dynamically
+ templates = get_component_templates(self.course)
+ advanced_templates = templates[-1]["templates"]
+ expected_count = len(advanced_templates)
+ # Verify all advanced templates have support_level=True
+ for template in advanced_templates:
+ self.assertTrue(template["support_level"])
def test_xblock_masquerading_as_problem(self):
"""
@@ -4628,3 +4636,93 @@ def test_xblock_edit_view_contains_resources(self):
self.assertGreater(len(resource_links), 0, f"No CSS resources found in HTML. Found: {resource_links}")
self.assertGreater(len(script_sources), 0, f"No JS resources found in HTML. Found: {script_sources}")
+
+
+class TestGetSafeReturnTo(TestCase):
+ """
+ Tests for _get_safe_return_to validation.
+ """
+
+ def setUp(self):
+ super().setUp()
+ self.factory = RequestFactory()
+
+ def _make_request(self, return_to=None):
+ """Build a GET request with an optional returnTo query parameter."""
+ url = '/dummy'
+ if return_to is not None:
+ url = f'/dummy?returnTo={return_to}'
+ return self.factory.get(url)
+
+ # -- valid inputs --------------------------------------------------------
+
+ def test_valid_relative_path(self):
+ request = self._make_request('/course/123')
+ self.assertEqual(_get_safe_return_to(request), '/course/123')
+
+ def test_valid_root_path(self):
+ request = self._make_request('/')
+ self.assertEqual(_get_safe_return_to(request), '/')
+
+ def test_valid_relative_path_with_query_string(self):
+ request = self.factory.get('/dummy', {'returnTo': '/course/123?tab=outline'})
+ self.assertEqual(_get_safe_return_to(request), '/course/123?tab=outline')
+
+ def test_valid_absolute_url_same_origin(self):
+ request = self.factory.get('/dummy', {'returnTo': 'http://testserver/course/123'})
+ self.assertEqual(_get_safe_return_to(request), 'http://testserver/course/123')
+
+ # -- empty / missing values ----------------------------------------------
+
+ def test_missing_parameter(self):
+ request = self.factory.get('/dummy')
+ self.assertIsNone(_get_safe_return_to(request))
+
+ def test_empty_string(self):
+ request = self._make_request('')
+ self.assertIsNone(_get_safe_return_to(request))
+
+ def test_whitespace_only(self):
+ request = self.factory.get('/dummy', {'returnTo': ' '})
+ self.assertIsNone(_get_safe_return_to(request))
+
+ # -- protocol-relative / different origin --------------------------------
+
+ def test_protocol_relative_url_rejected(self):
+ request = self._make_request('//evil.com/path')
+ self.assertIsNone(_get_safe_return_to(request))
+
+ def test_absolute_url_different_origin_rejected(self):
+ request = self.factory.get('/dummy', {'returnTo': 'https://evil.com/steal'})
+ self.assertIsNone(_get_safe_return_to(request))
+
+ def test_absolute_url_different_scheme_rejected(self):
+ request = self.factory.get('/dummy', {'returnTo': 'https://testserver/course'})
+ self.assertIsNone(_get_safe_return_to(request))
+
+ # -- relative paths that don't start with / ------------------------------
+
+ def test_bare_relative_path_rejected(self):
+ request = self._make_request('course/123')
+ self.assertIsNone(_get_safe_return_to(request))
+
+ # -- control characters --------------------------------------------------
+
+ def test_null_byte_rejected(self):
+ request = self.factory.get('/dummy', {'returnTo': '/course/\x00'})
+ self.assertIsNone(_get_safe_return_to(request))
+
+ def test_newline_rejected(self):
+ request = self.factory.get('/dummy', {'returnTo': '/course/\n/path'})
+ self.assertIsNone(_get_safe_return_to(request))
+
+ def test_tab_character_rejected(self):
+ request = self.factory.get('/dummy', {'returnTo': '/course/\t/path'})
+ self.assertIsNone(_get_safe_return_to(request))
+
+ # -- length limit --------------------------------------------------------
+
+ def test_url_over_2048_chars_rejected(self):
+ long_url = '/' + 'a' * 2048
+ request = self.factory.get('/dummy', {'returnTo': long_url})
+ self.assertIsNone(_get_safe_return_to(request))
diff --git a/cms/djangoapps/contentstore/views/tests/test_videos.py b/cms/djangoapps/contentstore/views/tests/test_videos.py
index 2d17229f59c1..cf2aefb9935d 100644
--- a/cms/djangoapps/contentstore/views/tests/test_videos.py
+++ b/cms/djangoapps/contentstore/views/tests/test_videos.py
@@ -9,13 +9,12 @@
from contextlib import contextmanager
from datetime import datetime
from io import StringIO
-from unittest.mock import Mock, patch
+from unittest.mock import Mock, call, patch
import dateutil.parser
from common.djangoapps.student.tests.factories import UserFactory
import ddt
import pytz
-from django.test import TestCase
from django.conf import settings
from django.test.utils import override_settings
from django.urls import reverse
@@ -33,10 +32,7 @@
from cms.djangoapps.contentstore.tests.utils import CourseTestCase
from cms.djangoapps.contentstore.utils import reverse_course_url
from openedx.core.djangoapps.profile_images.tests.helpers import make_image_file
-from openedx.core.djangoapps.video_pipeline.config.waffle import (
- DEPRECATE_YOUTUBE,
- ENABLE_DEVSTACK_VIDEO_UPLOADS,
-)
+from openedx.core.djangoapps.video_pipeline.config.waffle import DEPRECATE_YOUTUBE
from openedx.core.djangoapps.waffle_utils.models import WaffleFlagCourseOverrideModel
from xmodule.modulestore.django import modulestore
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
@@ -52,11 +48,13 @@
TranscriptProvider,
StatusDisplayStrings,
convert_video_status,
- storage_service_bucket,
- storage_service_key,
- PUBLIC_VIDEO_SHARE
+ PUBLIC_VIDEO_SHARE,
)
+# Constant defined to make it clear when we're grabbing the kwargs from a
+# unittest.mock.call (which is a list of [args, kwargs]).
+CALL_KW = 1
+
class VideoUploadTestBase:
"""
@@ -167,6 +165,40 @@ def _get_previous_upload(self, edx_video_id):
if video["edx_video_id"] == edx_video_id
)
+ @contextmanager
+ def patch_presign_url(self, files):
+ """
+ Decorator that patches boto3 to mock out S3 URL presigning.
+
+ Assumes that the only client in use is S3, and that only the presigning
+ method will be called. Makes assertions about what calls were made.
+
+ Decorator yields a result dictionary that will be populated *after* the
+ context closes. The one key is "calls", a list of call objects to the mock.
+
+ Arguments:
+ files: List of files to use for upload (dict of file_name and content_type)
+ """
+ mock_gen_url = Mock(side_effect=[
+ 'http://example.com/url_{}'.format(file_info['file_name'])
+ for file_info in files
+ ])
+ mock_s3_client = Mock()
+ mock_s3_client.generate_presigned_url = mock_gen_url
+ with patch(
+ 'cms.djangoapps.contentstore.video_storage_handlers.boto3.client',
+ return_value=mock_s3_client
+ ) as mock_boto_client:
+ results = {}
+ try:
+ yield results # run wrapped block
+ finally:
+ results['calls'] = mock_gen_url.call_args_list
+
+ # Ensure that we're only trying to load the S3 client
+ for c in mock_boto_client.call_args_list:
+ self.assertEqual(c, call('s3'))
+
class VideoStudioAccessTestsMixin:
"""
@@ -215,10 +247,7 @@ class VideoUploadPostTestsMixin:
"""
Shared test cases for video post tests.
"""
- @override_settings(AWS_ACCESS_KEY_ID='test_key_id', AWS_SECRET_ACCESS_KEY='test_secret')
- @patch('boto.s3.key.Key')
- @patch('cms.djangoapps.contentstore.video_storage_handlers.S3Connection')
- def test_post_success(self, mock_conn, mock_key):
+ def test_post_success(self):
files = [
{
'file_name': 'first.mp4',
@@ -238,63 +267,42 @@ def test_post_success(self, mock_conn, mock_key):
},
]
- bucket = Mock()
- mock_conn.return_value = Mock(get_bucket=Mock(return_value=bucket))
- mock_key_instances = [
- Mock(
- generate_url=Mock(
- return_value='http://example.com/url_{}'.format(file_info['file_name'])
- )
+ with self.patch_presign_url(files) as presign_results:
+ response = self.client.post(
+ self.url,
+ json.dumps({'files': files}),
+ content_type='application/json'
)
- for file_info in files
- ]
- # If extra calls are made, return a dummy
- mock_key.side_effect = mock_key_instances + [Mock()]
-
- response = self.client.post(
- self.url,
- json.dumps({'files': files}),
- content_type='application/json'
- )
self.assertEqual(response.status_code, 200)
response_obj = json.loads(response.content.decode('utf-8'))
- mock_conn.assert_called_once_with(
- aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
- aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY
- )
self.assertEqual(len(response_obj['files']), len(files))
- self.assertEqual(mock_key.call_count, len(files))
+ presign_calls = presign_results['calls']
+ self.assertEqual(len(presign_calls), len(files))
for i, file_info in enumerate(files):
- # Ensure Key was set up correctly and extract id
- key_call_args, __ = mock_key.call_args_list[i]
- self.assertEqual(key_call_args[0], bucket)
+ call_kwargs = presign_calls[i][CALL_KW]
+
+ self.assertEqual(call_kwargs['ClientMethod'], 'put_object')
path_match = re.match(
(
settings.VIDEO_UPLOAD_PIPELINE['ROOT_PATH'] +
'/([a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12})$'
),
- key_call_args[1]
+ call_kwargs['Params']['Key']
)
self.assertIsNotNone(path_match)
video_id = path_match.group(1)
- mock_key_instance = mock_key_instances[i]
-
- mock_key_instance.set_metadata.assert_any_call(
- 'course_video_upload_token',
- self.test_token
- )
- mock_key_instance.set_metadata.assert_any_call(
- 'client_video_id',
- file_info['file_name']
- )
- mock_key_instance.set_metadata.assert_any_call('course_key', str(self.course.id))
- mock_key_instance.generate_url.assert_called_once_with(
- KEY_EXPIRATION_IN_SECONDS,
- 'PUT',
- headers={'Content-Type': file_info['content_type']}
+ self.assertEqual(
+ call_kwargs['Params']['Metadata'],
+ {
+ 'course_video_upload_token': self.test_token,
+ 'client_video_id': file_info['file_name'],
+ 'course_key': str(self.course.id),
+ }
)
+ self.assertEqual(call_kwargs['Params']['ContentType'], file_info['content_type'])
+ self.assertEqual(call_kwargs['ExpiresIn'], KEY_EXPIRATION_IN_SECONDS)
# Ensure VAL was updated
val_info = get_video_info(video_id)
@@ -307,7 +315,7 @@ def test_post_success(self, mock_conn, mock_key):
# Ensure response is correct
response_file = response_obj['files'][i]
self.assertEqual(response_file['file_name'], file_info['file_name'])
- self.assertEqual(response_file['upload_url'], mock_key_instance.generate_url())
+ self.assertEqual(response_file['upload_url'], f"http://example.com/url_{file_info['file_name']}")
def test_post_non_json(self):
response = self.client.post(self.url, {"files": []})
@@ -479,9 +487,6 @@ def test_get_html_paginated(self):
self.assertEqual(response.status_code, 200)
self.assertContains(response, 'video_upload_pagination')
- @override_settings(AWS_ACCESS_KEY_ID="test_key_id", AWS_SECRET_ACCESS_KEY="test_secret")
- @patch("boto.s3.key.Key")
- @patch("cms.djangoapps.contentstore.video_storage_handlers.S3Connection")
@ddt.data(
(
[
@@ -511,28 +516,17 @@ def test_get_html_paginated(self):
)
)
@ddt.unpack
- def test_video_supported_file_formats(self, files, expected_status, mock_conn, mock_key):
+ def test_video_supported_file_formats(self, files, expected_status):
"""
Test that video upload works correctly against supported and unsupported file formats.
"""
- mock_conn.get_bucket = Mock()
- mock_key_instances = [
- Mock(
- generate_url=Mock(
- return_value="http://example.com/url_{}".format(file_info["file_name"])
- )
- )
- for file_info in files
- ]
- # If extra calls are made, return a dummy
- mock_key.side_effect = mock_key_instances + [Mock()]
-
# Check supported formats
- response = self.client.post(
- self.url,
- json.dumps({"files": files}),
- content_type="application/json"
- )
+ with self.patch_presign_url(files):
+ response = self.client.post(
+ self.url,
+ json.dumps({"files": files}),
+ content_type="application/json"
+ )
self.assertEqual(response.status_code, expected_status)
response = json.loads(response.content.decode('utf-8'))
@@ -542,19 +536,12 @@ def test_video_supported_file_formats(self, files, expected_status, mock_conn, m
self.assertIn('error', response)
self.assertEqual(response['error'], "Request 'files' entry contain unsupported content_type")
- @override_settings(AWS_ACCESS_KEY_ID='test_key_id', AWS_SECRET_ACCESS_KEY='test_secret')
- @patch('cms.djangoapps.contentstore.video_storage_handlers.S3Connection')
- def test_upload_with_non_ascii_charaters(self, mock_conn):
+ def test_upload_with_non_ascii_characters(self):
"""
Test that video uploads throws error message when file name contains special characters.
"""
- mock_conn.get_bucket = Mock()
file_name = 'test\u2019_file.mp4'
files = [{'file_name': file_name, 'content_type': 'video/mp4'}]
-
- bucket = Mock()
- mock_conn.return_value = Mock(get_bucket=Mock(return_value=bucket))
-
response = self.client.post(
self.url,
json.dumps({'files': files}),
@@ -564,67 +551,24 @@ def test_upload_with_non_ascii_charaters(self, mock_conn):
response = json.loads(response.content.decode('utf-8'))
self.assertEqual(response['error'], 'The file name for %s must contain only ASCII characters.' % file_name)
- @override_settings(AWS_ACCESS_KEY_ID='test_key_id', AWS_SECRET_ACCESS_KEY='test_secret', AWS_SECURITY_TOKEN='token')
- @patch('boto.s3.key.Key')
- @patch('cms.djangoapps.contentstore.video_storage_handlers.S3Connection')
- @override_waffle_flag(ENABLE_DEVSTACK_VIDEO_UPLOADS, active=True)
- def test_devstack_upload_connection(self, mock_conn, mock_key):
- files = [{'file_name': 'first.mp4', 'content_type': 'video/mp4'}]
- mock_conn.get_bucket = Mock()
- mock_key_instances = [
- Mock(
- generate_url=Mock(
- return_value='http://example.com/url_{}'.format(file_info['file_name'])
- )
- )
- for file_info in files
- ]
- mock_key.side_effect = mock_key_instances
- response = self.client.post(
- self.url,
- json.dumps({'files': files}),
- content_type='application/json'
- )
-
- self.assertEqual(response.status_code, 200)
- mock_conn.assert_called_once_with(
- aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
- aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY,
- security_token=settings.AWS_SECURITY_TOKEN
- )
-
- @patch('boto.s3.key.Key')
- @patch('cms.djangoapps.contentstore.video_storage_handlers.S3Connection')
- def test_send_course_to_vem_pipeline(self, mock_conn, mock_key):
+ def test_send_course_to_vem_pipeline(self):
"""
Test that uploads always go to VEM S3 bucket by default.
"""
- mock_conn.get_bucket = Mock()
files = [{'file_name': 'first.mp4', 'content_type': 'video/mp4'}]
- mock_key_instances = [
- Mock(
- generate_url=Mock(
- return_value='http://example.com/url_{}'.format(file_info['file_name'])
- )
+ with self.patch_presign_url(files) as presign_results:
+ response = self.client.post(
+ self.url,
+ json.dumps({'files': files}),
+ content_type='application/json'
)
- for file_info in files
- ]
- mock_key.side_effect = mock_key_instances
-
- response = self.client.post(
- self.url,
- json.dumps({'files': files}),
- content_type='application/json'
- )
self.assertEqual(response.status_code, 200)
- mock_conn.return_value.get_bucket.assert_called_once_with(
- settings.VIDEO_UPLOAD_PIPELINE['VEM_S3_BUCKET'], validate=False # pylint: disable=unsubscriptable-object
+ self.assertEqual(
+ presign_results['calls'][0][CALL_KW]['Params']['Bucket'],
+ settings.VIDEO_UPLOAD_PIPELINE['VEM_S3_BUCKET']
)
- @override_settings(AWS_ACCESS_KEY_ID='test_key_id', AWS_SECRET_ACCESS_KEY='test_secret')
- @patch('boto.s3.key.Key')
- @patch('cms.djangoapps.contentstore.video_storage_handlers.S3Connection')
@ddt.data(
{
'global_waffle': True,
@@ -642,52 +586,26 @@ def test_send_course_to_vem_pipeline(self, mock_conn, mock_key):
'expect_token': True
}
)
- def test_video_upload_token_in_meta(self, data, mock_conn, mock_key):
+ def test_video_upload_token_in_meta(self, data):
"""
Test video upload token in s3 metadata.
"""
- @contextmanager
- def proxy_manager(manager, ignore_manager):
- """
- This acts as proxy to the original manager in the arguments given
- the original manager is not set to be ignored.
- """
- if ignore_manager:
- yield
- else:
- with manager:
- yield
-
file_data = {
'file_name': 'first.mp4',
'content_type': 'video/mp4',
}
- mock_conn.get_bucket = Mock()
- mock_key_instance = Mock(
- generate_url=Mock(
- return_value='http://example.com/url_{}'.format(file_data['file_name'])
- )
- )
- # If extra calls are made, return a dummy
- mock_key.side_effect = [mock_key_instance]
-
- # expected args to be passed to `set_metadata`.
- expected_args = ('course_video_upload_token', self.test_token)
-
with patch.object(WaffleFlagCourseOverrideModel, 'override_value', return_value=data['course_override']):
with override_waffle_flag(DEPRECATE_YOUTUBE, active=data['global_waffle']):
- response = self.client.post(
- self.url,
- json.dumps({'files': [file_data]}),
- content_type='application/json'
- )
+ with self.patch_presign_url([file_data]) as presign_results:
+ response = self.client.post(
+ self.url,
+ json.dumps({'files': [file_data]}),
+ content_type='application/json'
+ )
self.assertEqual(response.status_code, 200)
- with proxy_manager(self.assertRaises(AssertionError), data['expect_token']):
- # if we're not expecting token then following should raise assertion error and
- # if we're expecting token then we will be able to find the call to set the token
- # in s3 metadata.
- mock_key_instance.set_metadata.assert_any_call(*expected_args)
+ actual_token = presign_results['calls'][0][CALL_KW]['Params']['Metadata'].get('course_video_upload_token')
+ self.assertEqual(actual_token, self.test_token if data['expect_token'] else None)
def _assert_video_removal(self, url, edx_video_id, deleted_videos):
"""
@@ -1460,47 +1378,37 @@ def test_remove_transcript_preferences_not_found(self):
)
@ddt.unpack
@override_settings(AWS_ACCESS_KEY_ID='test_key_id', AWS_SECRET_ACCESS_KEY='test_secret')
- @patch('boto.s3.key.Key')
- @patch('cms.djangoapps.contentstore.video_storage_handlers.S3Connection')
@patch('cms.djangoapps.contentstore.video_storage_handlers.get_transcript_preferences')
def test_transcript_preferences_metadata(self, transcript_preferences, is_video_transcript_enabled,
- mock_transcript_preferences, mock_conn, mock_key):
+ mock_transcript_preferences):
"""
Tests that transcript preference metadata is only set if it is video transcript feature is enabled and
transcript preferences are already stored in the system.
"""
file_name = 'test-video.mp4'
- request_data = {'files': [{'file_name': file_name, 'content_type': 'video/mp4'}]}
+ files = [{'file_name': file_name, 'content_type': 'video/mp4'}]
mock_transcript_preferences.return_value = transcript_preferences
- bucket = Mock()
- mock_conn.return_value = Mock(get_bucket=Mock(return_value=bucket))
- mock_key_instance = Mock(
- generate_url=Mock(
- return_value=f'http://example.com/url_{file_name}'
- )
- )
- # If extra calls are made, return a dummy
- mock_key.side_effect = [mock_key_instance] + [Mock()]
-
videos_handler_url = reverse_course_url('videos_handler', self.course.id)
with patch(
'openedx.core.djangoapps.video_config.models.VideoTranscriptEnabledFlag.feature_enabled'
) as video_transcript_feature:
video_transcript_feature.return_value = is_video_transcript_enabled
- response = self.client.post(videos_handler_url, json.dumps(request_data), content_type='application/json')
+ with self.patch_presign_url(files) as presign_results:
+ response = self.client.post(
+ videos_handler_url, json.dumps({'files': files}),
+ content_type='application/json',
+ )
self.assertEqual(response.status_code, 200)
- # Ensure `transcript_preferences` was set up in Key correctly if sent through request.
+ # Ensure `transcript_preferences` was set in metadata correctly if sent through request.
+ actual_value = presign_results['calls'][0][CALL_KW]['Params']['Metadata'].get('transcript_preferences')
if is_video_transcript_enabled and transcript_preferences:
- mock_key_instance.set_metadata.assert_any_call('transcript_preferences', json.dumps(transcript_preferences))
+ self.assertEqual(actual_value, json.dumps(transcript_preferences))
else:
- with self.assertRaises(AssertionError):
- mock_key_instance.set_metadata.assert_any_call(
- 'transcript_preferences', json.dumps(transcript_preferences)
- )
+ self.assertEqual(actual_value, None)
@patch.dict("django.conf.settings.FEATURES", {"ENABLE_VIDEO_UPLOAD_PIPELINE": True})
@@ -1644,29 +1552,6 @@ def _test_video_feature(self, flag, key, override_fn, is_enabled):
self.assertEqual(response.json()[key], is_enabled)
-class GetStorageBucketTestCase(TestCase):
- """ This test just check that connection works and returns the bucket.
- It does not involve any mocking and triggers errors if has any import issue.
- """
- @override_settings(AWS_ACCESS_KEY_ID='test_key_id', AWS_SECRET_ACCESS_KEY='test_secret')
- @override_settings(VIDEO_UPLOAD_PIPELINE={
- "VEM_S3_BUCKET": "vem_test_bucket", "BUCKET": "test_bucket", "ROOT_PATH": "test_root"
- })
- def test_storage_bucket(self):
- """ get bucket and generate url. It will not hit actual s3."""
- bucket = storage_service_bucket()
- edx_video_id = 'dummy_video'
- key = storage_service_key(bucket, file_name=edx_video_id)
- upload_url = key.generate_url(
- KEY_EXPIRATION_IN_SECONDS,
- 'PUT',
- headers={'Content-Type': 'mp4'}
- )
-
- self.assertIn("https://vem_test_bucket.s3.amazonaws.com:443/test_root/", upload_url)
- self.assertIn(edx_video_id, upload_url)
-
-
class CourseYoutubeEdxVideoIds(ModuleStoreTestCase):
"""
This test checks youtube videos in a course
diff --git a/cms/djangoapps/contentstore/views/videos.py b/cms/djangoapps/contentstore/views/videos.py
index 2eac141b9c9e..499c73c29a23 100644
--- a/cms/djangoapps/contentstore/views/videos.py
+++ b/cms/djangoapps/contentstore/views/videos.py
@@ -23,8 +23,6 @@
videos_index_html as videos_index_html_source_function,
videos_index_json as videos_index_json_source_function,
videos_post as videos_post_source_function,
- storage_service_bucket as storage_service_bucket_source_function,
- storage_service_key as storage_service_key_source_function,
send_video_status_update as send_video_status_update_source_function,
is_status_update_request as is_status_update_request_source_function,
get_course_youtube_edx_video_ids,
@@ -212,20 +210,6 @@ def videos_post(course, request):
return videos_post_source_function(course, request)
-def storage_service_bucket():
- """
- Exposes helper method without breaking existing bindings/dependencies
- """
- return storage_service_bucket_source_function()
-
-
-def storage_service_key(bucket, file_name):
- """
- Exposes helper method without breaking existing bindings/dependencies
- """
- return storage_service_key_source_function(bucket, file_name)
-
-
def send_video_status_update(updates):
"""
Exposes helper method without breaking existing bindings/dependencies
diff --git a/cms/static/images/large-games-icon.svg b/cms/static/images/large-games-icon.svg
new file mode 100644
index 000000000000..d23b38c7748d
--- /dev/null
+++ b/cms/static/images/large-games-icon.svg
@@ -0,0 +1,10 @@
+
diff --git a/cms/static/js/spec/views/modals/edit_xblock_spec.js b/cms/static/js/spec/views/modals/edit_xblock_spec.js
index 06b4413784cb..0855a850b97d 100644
--- a/cms/static/js/spec/views/modals/edit_xblock_spec.js
+++ b/cms/static/js/spec/views/modals/edit_xblock_spec.js
@@ -185,6 +185,68 @@ describe('EditXBlockModal', function() {
});
});
+ describe('isSafeReturnTo', function() {
+ var isSafeReturnTo = EditXBlockModal.isSafeReturnTo;
+
+ // -- valid inputs ---------------------------------------------------
+
+ it('accepts a root-relative path', function() {
+ expect(isSafeReturnTo('/course/123')).toBe(true);
+ });
+
+ it('accepts the root path', function() {
+ expect(isSafeReturnTo('/')).toBe(true);
+ });
+
+ it('accepts a root-relative path with query string', function() {
+ expect(isSafeReturnTo('/course/123?tab=outline')).toBe(true);
+ });
+
+ it('accepts an absolute URL with the same origin', function() {
+ expect(isSafeReturnTo(window.location.origin + '/course/123')).toBe(true);
+ });
+
+ // -- empty / missing values -----------------------------------------
+
+ it('rejects an empty string', function() {
+ expect(isSafeReturnTo('')).toBe(false);
+ });
+
+ it('rejects null', function() {
+ expect(isSafeReturnTo(null)).toBe(false);
+ });
+
+ it('rejects undefined', function() {
+ expect(isSafeReturnTo(undefined)).toBe(false);
+ });
+
+ it('rejects a non-string value', function() {
+ expect(isSafeReturnTo(42)).toBe(false);
+ });
+
+ // -- protocol-relative / different origin ---------------------------
+
+ it('rejects protocol-relative URLs', function() {
+ expect(isSafeReturnTo('//evil.com/path')).toBe(false);
+ });
+
+ it('rejects an absolute URL with a different origin', function() {
+ expect(isSafeReturnTo('https://evil.com/steal')).toBe(false);
+ });
+
+ // -- relative paths without leading / -------------------------------
+
+ it('rejects a bare relative path', function() {
+ expect(isSafeReturnTo('course/123')).toBe(false);
+ });
+
+ // -- javascript: scheme ---------------------------------------------
+
+ it('rejects javascript: scheme', function() {
+ expect(isSafeReturnTo('javascript:alert(1)')).toBe(false); // eslint-disable-line no-script-url
+ });
+ });
+
describe('XModule Editor (settings only)', function() {
var mockXModuleEditorHtml;
diff --git a/cms/static/js/views/modals/edit_xblock.js b/cms/static/js/views/modals/edit_xblock.js
index 586d27d8b284..aef409c56150 100644
--- a/cms/static/js/views/modals/edit_xblock.js
+++ b/cms/static/js/views/modals/edit_xblock.js
@@ -8,6 +8,27 @@ define(['jquery', 'underscore', 'backbone', 'gettext', 'js/views/modals/base_mod
function($, _, Backbone, gettext, BaseModal, ViewUtils, XBlockViewUtils, XBlockEditorView) {
'use strict';
+ /**
+ * Returns true when ``url`` is safe to use as a same-origin redirect
+ * destination. Accepted values are:
+ * - Root-relative paths beginning with '/' but NOT '//' (protocol-
+ * relative URLs such as //evil.com would bypass the check).
+ * - Absolute URLs whose origin matches the current page's origin.
+ */
+ function isSafeReturnTo(url) {
+ if (typeof url !== 'string' || url.length === 0) {
+ return false;
+ }
+ if (url.charAt(0) === '/' && url.charAt(1) !== '/') {
+ return true;
+ }
+ try {
+ return new URL(url).origin === window.location.origin;
+ } catch (e) {
+ return false;
+ }
+ }
+
var EditXBlockModal = BaseModal.extend({
events: _.extend({}, BaseModal.prototype.events, {
'click .action-save': 'save',
@@ -234,6 +255,13 @@ function($, _, Backbone, gettext, BaseModal, ViewUtils, XBlockViewUtils, XBlockE
console.error(e);
}
+ var returnTo = this.editOptions && this.editOptions.returnTo;
+ if (returnTo && isSafeReturnTo(returnTo)) {
+ this.hide();
+ window.location.href = returnTo;
+ return;
+ }
+
var refresh = this.editOptions.refresh;
this.hide();
if (refresh) {
@@ -241,6 +269,18 @@ function($, _, Backbone, gettext, BaseModal, ViewUtils, XBlockViewUtils, XBlockE
}
},
+ cancel: function(event) {
+ if (event) {
+ event.preventDefault();
+ event.stopPropagation();
+ }
+ this.hide();
+ var returnTo = this.editOptions && this.editOptions.returnTo;
+ if (returnTo && isSafeReturnTo(returnTo)) {
+ window.location.href = returnTo;
+ }
+ },
+
hide: function() {
// Notify child views to stop listening events
Backbone.trigger('xblock:editorModalHidden');
@@ -296,5 +336,8 @@ function($, _, Backbone, gettext, BaseModal, ViewUtils, XBlockViewUtils, XBlockE
});
+ // Expose for unit testing.
+ EditXBlockModal.isSafeReturnTo = isSafeReturnTo;
+
return EditXBlockModal;
});
diff --git a/cms/static/js/views/pages/container.js b/cms/static/js/views/pages/container.js
index d50f6b4bbe4a..f8e67a47c074 100644
--- a/cms/static/js/views/pages/container.js
+++ b/cms/static/js/views/pages/container.js
@@ -512,6 +512,7 @@ function($, _, Backbone, gettext, BasePage,
if((useNewTextEditor === 'True' && blockType === 'html')
|| (useNewVideoEditor === 'True' && blockType === 'video')
|| (useNewProblemEditor === 'True' && blockType === 'problem')
+ || (blockType === 'games') || (blockType === 'invideoquiz')
) {
var destinationUrl = primaryHeader.attr('authoring_MFE_base_url')
+ '/' + blockType
@@ -1183,7 +1184,9 @@ function($, _, Backbone, gettext, BasePage,
// open mfe editors for new blocks only and not for content imported from libraries
if(!data.hasOwnProperty('upstreamRef') && ((useNewTextEditor === 'True' && blockType.includes('html'))
|| (useNewVideoEditor === 'True' && blockType.includes('video'))
- || (useNewProblemEditor === 'True' && blockType.includes('problem')))
+ || (useNewProblemEditor === 'True' && blockType.includes('problem'))
+ || blockType.includes('games')
+ || blockType.includes('invideoquiz'))
){
if (this.options.isIframeEmbed && (this.isSplitTestContentPage || this.isVerticalContentPage)) {
return this.postMessageToParent({
diff --git a/cms/static/sass/assets/_graphics.scss b/cms/static/sass/assets/_graphics.scss
index afb830d5dd71..58141c445a75 100644
--- a/cms/static/sass/assets/_graphics.scss
+++ b/cms/static/sass/assets/_graphics.scss
@@ -80,3 +80,9 @@
height: ($baseline*3);
background: url('#{$static-path}/images/large-itembank-icon.png') center no-repeat;
}
+
+.large-games-icon {
+ display: inline-block;
+ width: ($baseline*3);
+ height: ($baseline*3);
+ background: url('#{$static-path}/images/large-games-icon.svg') center no-repeat; }
diff --git a/cms/static/sass/elements/_modules.scss b/cms/static/sass/elements/_modules.scss
index 1e9f52691fc8..f1c871ba1d95 100644
--- a/cms/static/sass/elements/_modules.scss
+++ b/cms/static/sass/elements/_modules.scss
@@ -163,8 +163,8 @@
display: inline-block;
color: $uxpl-green-base;
- background-color: theme-color("inverse");
- border-color: theme-color("inverse");
+ background-color: $white;
+ border-color: $white;
border-radius: 3px;
font-size: 90%;
diff --git a/cms/templates/container_editor.html b/cms/templates/container_editor.html
index e7585d7b9664..900dac80abf9 100644
--- a/cms/templates/container_editor.html
+++ b/cms/templates/container_editor.html
@@ -119,12 +119,13 @@
function (XBlockInfo, EditXBlockModal) {
var decodedActionName = '${action_name|n, decode.utf8}';
var encodedXBlockDetails = ${xblock_info | n, dump_js_escaped_json};
+ var returnTo = '${return_to or "" | n, js_escaped_string}';
if (decodedActionName === 'edit') {
var editXBlockModal = new EditXBlockModal();
var xblockInfoInstance = new XBlockInfo(encodedXBlockDetails);
- editXBlockModal.edit([], xblockInfoInstance, {});
+ editXBlockModal.edit([], xblockInfoInstance, {returnTo: returnTo || null});
}
});
%static:webpack>
diff --git a/cms/templates/js/show-correctness-editor.underscore b/cms/templates/js/show-correctness-editor.underscore
index 3db6c3c27a5c..1b0dd896747a 100644
--- a/cms/templates/js/show-correctness-editor.underscore
+++ b/cms/templates/js/show-correctness-editor.underscore
@@ -35,6 +35,13 @@
<% } %>
<%- gettext('If the subsection does not have a due date, learners always see their scores when they submit answers to assessments.') %>
+
+
+ <%- gettext('Learners do not see question-level correctness or scores before or after the due date. However, once the due date passes, they can see their overall score for the subsection on the Progress page.') %>
+
diff --git a/common/djangoapps/student/auth.py b/common/djangoapps/student/auth.py
index e199142fe377..047f0174a062 100644
--- a/common/djangoapps/student/auth.py
+++ b/common/djangoapps/student/auth.py
@@ -24,6 +24,7 @@
OrgInstructorRole,
OrgLibraryUserRole,
OrgStaffRole,
+ strict_role_checking,
)
# Studio permissions:
@@ -115,8 +116,9 @@ def get_user_permissions(user, course_key, org=None, service_variant=None):
return STUDIO_NO_PERMISSIONS
# Staff have all permissions except EDIT_ROLES:
- if OrgStaffRole(org=org).has_user(user) or (course_key and user_has_role(user, CourseStaffRole(course_key))):
- return STUDIO_VIEW_USERS | STUDIO_EDIT_CONTENT | STUDIO_VIEW_CONTENT
+ with strict_role_checking():
+ if OrgStaffRole(org=org).has_user(user) or (course_key and user_has_role(user, CourseStaffRole(course_key))):
+ return STUDIO_VIEW_USERS | STUDIO_EDIT_CONTENT | STUDIO_VIEW_CONTENT
# Otherwise, for libraries, users can view only:
if course_key and isinstance(course_key, LibraryLocator):
diff --git a/common/djangoapps/student/tests/test_authz.py b/common/djangoapps/student/tests/test_authz.py
index c0b88e6318b5..70636e04b68a 100644
--- a/common/djangoapps/student/tests/test_authz.py
+++ b/common/djangoapps/student/tests/test_authz.py
@@ -11,6 +11,7 @@
from django.test import TestCase, override_settings
from opaque_keys.edx.locator import CourseLocator
+from common.djangoapps.student.models.user import CourseAccessRole
from common.djangoapps.student.auth import (
add_users,
has_studio_read_access,
@@ -305,6 +306,23 @@ def test_limited_staff_no_studio_access_cms(self):
assert not has_studio_read_access(self.limited_staff, self.course_key)
assert not has_studio_write_access(self.limited_staff, self.course_key)
+ @override_settings(SERVICE_VARIANT='cms')
+ def test_limited_org_staff_no_studio_access_cms(self):
+ """
+        Verifies that org-level limited staff have no read and no write access when SERVICE_VARIANT is not 'lms'.
+ """
+ # Add a user as course_limited_staff on the org
+ # This is not possible using the course roles classes but is possible via Django admin so we
+ # insert a row into the model directly to test that scenario.
+ CourseAccessRole.objects.create(
+ user=self.limited_staff,
+ org=self.course_key.org,
+ role=CourseLimitedStaffRole.ROLE,
+ )
+
+ assert not has_studio_read_access(self.limited_staff, self.course_key)
+ assert not has_studio_write_access(self.limited_staff, self.course_key)
+
class CourseOrgGroupTest(TestCase):
"""
diff --git a/common/djangoapps/third_party_auth/api/serializers.py b/common/djangoapps/third_party_auth/api/serializers.py
index 3e8513de7312..a510cbe07a07 100644
--- a/common/djangoapps/third_party_auth/api/serializers.py
+++ b/common/djangoapps/third_party_auth/api/serializers.py
@@ -20,4 +20,7 @@ def get_username(self, social_user):
def get_remote_id(self, social_user):
""" Gets remote id from social user based on provider """
+ remote_id_field_name = self.context.get('remote_id_field_name', None)
+ if remote_id_field_name:
+ return self.provider.get_remote_id_from_field_name(social_user, remote_id_field_name)
return self.provider.get_remote_id_from_social_auth(social_user)
diff --git a/common/djangoapps/third_party_auth/api/tests/test_views.py b/common/djangoapps/third_party_auth/api/tests/test_views.py
index f7834001d66b..61740268db90 100644
--- a/common/djangoapps/third_party_auth/api/tests/test_views.py
+++ b/common/djangoapps/third_party_auth/api/tests/test_views.py
@@ -38,8 +38,10 @@
PASSWORD = "edx"
-def get_mapping_data_by_usernames(usernames):
+def get_mapping_data_by_usernames(usernames, remote_id_field_name=False):
""" Generate mapping data used in response """
+ if remote_id_field_name:
+ return [{'username': username, 'remote_id': 'external_' + username} for username in usernames]
return [{'username': username, 'remote_id': 'remote_' + username} for username in usernames]
@@ -76,11 +78,13 @@ def setUp(self): # pylint: disable=arguments-differ
provider=google.backend_name,
uid=f'{username}@gmail.com',
)
- UserSocialAuth.objects.create(
+ usa = UserSocialAuth.objects.create(
user=user,
provider=testshib.backend_name,
uid=f'{testshib.slug}:remote_{username}',
)
+ usa.set_extra_data({'external_user_id': f'external_{username}'})
+ usa.refresh_from_db()
# Create another user not linked to any providers:
UserFactory.create(username=CARL_USERNAME, email=f'{CARL_USERNAME}@example.com', password=PASSWORD)
@@ -304,12 +308,20 @@ def test_list_all_user_mappings_oauth2(self, valid_call, expect_code, expect_dat
@ddt.data(
({'username': [ALICE_USERNAME, STAFF_USERNAME]}, 200,
get_mapping_data_by_usernames([ALICE_USERNAME, STAFF_USERNAME])),
+ ({'username': [ALICE_USERNAME, STAFF_USERNAME], 'remote_id_field_name': 'external_user_id'}, 200,
+ get_mapping_data_by_usernames([ALICE_USERNAME, STAFF_USERNAME], remote_id_field_name=True)),
({'remote_id': ['remote_' + ALICE_USERNAME, 'remote_' + STAFF_USERNAME, 'remote_' + CARL_USERNAME]}, 200,
get_mapping_data_by_usernames([ALICE_USERNAME, STAFF_USERNAME])),
+ ({'remote_id': ['remote_' + ALICE_USERNAME, 'remote_' + STAFF_USERNAME, 'remote_' + CARL_USERNAME],
+ 'remote_id_field_name': 'external_user_id'}, 200,
+ get_mapping_data_by_usernames([ALICE_USERNAME, STAFF_USERNAME], remote_id_field_name=True)),
({'username': [ALICE_USERNAME, CARL_USERNAME, STAFF_USERNAME]}, 200,
get_mapping_data_by_usernames([ALICE_USERNAME, STAFF_USERNAME])),
({'username': [ALICE_USERNAME], 'remote_id': ['remote_' + STAFF_USERNAME]}, 200,
get_mapping_data_by_usernames([ALICE_USERNAME, STAFF_USERNAME])),
+ ({'username': [ALICE_USERNAME], 'remote_id': ['remote_' + STAFF_USERNAME],
+ 'remote_id_field_name': 'external_user_id'}, 200,
+ get_mapping_data_by_usernames([ALICE_USERNAME, STAFF_USERNAME], remote_id_field_name=True)),
)
@ddt.unpack
def test_user_mappings_with_query_params_comma_separated(self, query_params, expect_code, expect_data):
@@ -321,6 +333,8 @@ def test_user_mappings_with_query_params_comma_separated(self, query_params, exp
for attr in ['username', 'remote_id']:
if attr in query_params:
params.append('{}={}'.format(attr, ','.join(query_params[attr])))
+ if 'remote_id_field_name' in query_params:
+ params.append('remote_id_field_name={}'.format(query_params['remote_id_field_name']))
url = "{}?{}".format(base_url, '&'.join(params))
response = self.client.get(url, HTTP_X_EDX_API_KEY=VALID_API_KEY)
self._verify_response(response, expect_code, expect_data)
@@ -328,12 +342,20 @@ def test_user_mappings_with_query_params_comma_separated(self, query_params, exp
@ddt.data(
({'username': [ALICE_USERNAME, STAFF_USERNAME]}, 200,
get_mapping_data_by_usernames([ALICE_USERNAME, STAFF_USERNAME])),
+ ({'username': [ALICE_USERNAME, STAFF_USERNAME], 'remote_id_field_name': 'external_user_id'}, 200,
+ get_mapping_data_by_usernames([ALICE_USERNAME, STAFF_USERNAME], remote_id_field_name=True)),
({'remote_id': ['remote_' + ALICE_USERNAME, 'remote_' + STAFF_USERNAME, 'remote_' + CARL_USERNAME]}, 200,
get_mapping_data_by_usernames([ALICE_USERNAME, STAFF_USERNAME])),
+ ({'remote_id': ['remote_' + ALICE_USERNAME, 'remote_' + STAFF_USERNAME, 'remote_' + CARL_USERNAME],
+ 'remote_id_field_name': 'external_user_id'}, 200,
+ get_mapping_data_by_usernames([ALICE_USERNAME, STAFF_USERNAME], remote_id_field_name=True)),
({'username': [ALICE_USERNAME, CARL_USERNAME, STAFF_USERNAME]}, 200,
get_mapping_data_by_usernames([ALICE_USERNAME, STAFF_USERNAME])),
({'username': [ALICE_USERNAME], 'remote_id': ['remote_' + STAFF_USERNAME]}, 200,
get_mapping_data_by_usernames([ALICE_USERNAME, STAFF_USERNAME])),
+ ({'username': [ALICE_USERNAME], 'remote_id': ['remote_' + STAFF_USERNAME],
+ 'remote_id_field_name': 'external_user_id'}, 200,
+ get_mapping_data_by_usernames([ALICE_USERNAME, STAFF_USERNAME], remote_id_field_name=True)),
)
@ddt.unpack
def test_user_mappings_with_query_params_multi_value_key(self, query_params, expect_code, expect_data):
@@ -345,6 +367,8 @@ def test_user_mappings_with_query_params_multi_value_key(self, query_params, exp
for attr in ['username', 'remote_id']:
if attr in query_params:
params.setlist(attr, query_params[attr])
+ if 'remote_id_field_name' in query_params:
+ params['remote_id_field_name'] = query_params['remote_id_field_name']
url = f"{base_url}?{params.urlencode()}"
response = self.client.get(url, HTTP_X_EDX_API_KEY=VALID_API_KEY)
self._verify_response(response, expect_code, expect_data)
diff --git a/common/djangoapps/third_party_auth/api/views.py b/common/djangoapps/third_party_auth/api/views.py
index c2b8b0dd6f39..89d55e2eecdc 100644
--- a/common/djangoapps/third_party_auth/api/views.py
+++ b/common/djangoapps/third_party_auth/api/views.py
@@ -323,6 +323,9 @@ class UserMappingView(ListAPIView):
GET /api/third_party_auth/v0/providers/{provider_id}/users?username={username1},{username2}
+ GET /api/third_party_auth/v0/providers/{provider_id}/users?username={username1}&
+ remote_id_field_name={external_id_field_name}
+
GET /api/third_party_auth/v0/providers/{provider_id}/users?username={username1}&usernames={username2}
GET /api/third_party_auth/v0/providers/{provider_id}/users?remote_id={remote_id1},{remote_id2}
@@ -346,6 +349,9 @@ class UserMappingView(ListAPIView):
* usernames: Optional. List of comma separated edX usernames to filter the result set.
e.g. ?usernames=bob123,jane456
+ * remote_id_field_name: Optional. The field name to use for the remote id lookup.
+        Useful when learners are coming from an external LMS, e.g. ?remote_id_field_name=ext_userid_sf
+
* page, page_size: Optional. Used for paging the result set, especially when getting
an unfiltered list.
@@ -415,6 +421,7 @@ def get_serializer_context(self):
remove idp_slug from the remote_id if there is any
"""
context = super().get_serializer_context()
+ context['remote_id_field_name'] = self.request.query_params.get('remote_id_field_name', None)
context['provider'] = self.provider
return context
diff --git a/common/djangoapps/third_party_auth/docs/how_tos/testing_saml_locally.rst b/common/djangoapps/third_party_auth/docs/how_tos/testing_saml_locally.rst
new file mode 100644
index 000000000000..2be1cc1dacf6
--- /dev/null
+++ b/common/djangoapps/third_party_auth/docs/how_tos/testing_saml_locally.rst
@@ -0,0 +1,156 @@
+Testing SAML Authentication Locally with MockSAML
+==================================================
+
+This guide walks through setting up and testing SAML authentication in a local Open edX devstack environment using MockSAML.com as a test Identity Provider (IdP).
+
+Overview
+--------
+
+SAML (Security Assertion Markup Language) authentication in Open edX requires three configuration objects to work together:
+
+1. **SAMLConfiguration**: Configures the Service Provider (SP) metadata - entity ID, keys, and organization info
+2. **SAMLProviderConfig**: Configures a specific Identity Provider (IdP) connection with metadata URL and attribute mappings
+3. **SAMLProviderData**: Stores the IdP's metadata (SSO URL, public key) fetched from the IdP's metadata endpoint
+
+**Critical Requirement**: The SAMLConfiguration object MUST have the slug "default" because this value is hardcoded in the authentication execution path at ``common/djangoapps/third_party_auth/models.py:906``.
+
+Prerequisites
+-------------
+
+* Local Open edX devstack running
+* Access to Django admin at http://localhost:18000/admin/
+* MockSAML.com account (free service for SAML testing)
+
+Step 1: Configure SAMLConfiguration
+------------------------------------
+
+The SAMLConfiguration defines your Open edX instance as a SAML Service Provider (SP).
+
+1. Navigate to Django Admin → Third Party Auth → SAML Configurations
+2. Click "Add SAML Configuration"
+3. Configure with these **required** values:
+
+ ============ ===================================================
+ Field Value
+ ============ ===================================================
+ Site localhost:18000
+ **Slug** **default** (MUST be "default" - hardcoded in code)
+ Entity ID https://saml.example.com/entityid
+ Enabled ✓ (checked)
+ ============ ===================================================
+
+4. For local testing with MockSAML, you can leave the keys blank.
+
+5. Optionally configure Organization Info (use default or customize):
+
+ .. code-block:: json
+
+ {
+ "en-US": {
+ "url": "http://localhost:18000",
+ "displayname": "Local Open edX",
+ "name": "localhost"
+ }
+ }
+
+6. Click "Save"
+
+Step 2: Configure SAMLProviderConfig
+-------------------------------------
+
+The SAMLProviderConfig connects to a specific SAML Identity Provider (MockSAML in this case).
+
+1. Navigate to Django Admin → Third Party Auth → Provider Configuration (SAML IdPs)
+2. Click "Add Provider Configuration (SAML IdP)"
+3. Configure with these values:
+
+ ========================= ===================================================
+ Field Value
+ ========================= ===================================================
+ Name Test Localhost (or any descriptive name)
+ Slug default (to match test URLs)
+ Backend Name tpa-saml
+ Entity ID https://saml.example.com/entityid
+ Metadata Source https://mocksaml.com/api/saml/metadata
+ Site localhost:18000
+ SAML Configuration Select the SAMLConfiguration created in Step 1
+ Enabled ✓ (checked)
+ Visible ☐ (unchecked for testing)
+ Skip hinted login dialog ✓ (checked - recommended)
+ Skip registration form ✓ (checked - recommended)
+ Skip email verification ✓ (checked - recommended)
+ Send to registration first ✓ (checked - recommended)
+ ========================= ===================================================
+
+4. Leave all attribute mappings (User ID, Email, Full Name, etc.) blank to use defaults
+5. Click "Save"
+
+**Important**: The Entity ID in SAMLProviderConfig MUST match the Entity ID in SAMLConfiguration.
+
+Step 3: Set IdP Data
+--------------------
+
+The SAMLProviderData stores metadata from the Identity Provider (MockSAML). Create a record with:
+
+* **Entity ID**: https://saml.example.com/entityid
+* **SSO URL**: https://mocksaml.com/api/saml/sso
+* **Public Key**: The IdP's signing certificate
+* **Expires At**: Set to 1 year from fetch time
+
+
+Step 4: Test SAML Authentication
+---------------------------------
+
+1. Navigate to: http://localhost:18000/auth/idp_redirect/saml-default
+2. You should be redirected to MockSAML.com
+3. Complete the authentication on MockSAML - just click "Sign In" with whatever is in the form.
+4. You should be redirected back to Open edX
+5. If this is a new user, you'll see the registration form
+6. After registration, you should be logged in
+
+Expected Behavior
+^^^^^^^^^^^^^^^^^
+
+1. Initial redirect to MockSAML (https://mocksaml.com/api/saml/sso)
+2. MockSAML displays the login page
+3. After authentication, MockSAML POSTs the SAML assertion back to Open edX
+4. Open edX validates the assertion and creates/logs in the user
+5. User is redirected to the dashboard or registration form (if new user)
+
+Reference Configuration
+-----------------------
+
+Here's a summary of a working test configuration:
+
+**SAMLConfiguration** (id=6):
+
+* Site: localhost:18000
+* Slug: **default**
+* Entity ID: https://saml.example.com/entityid
+* Enabled: True
+
+**SAMLProviderConfig** (id=11):
+
+* Name: Test Localhost
+* Slug: default
+* Entity ID: https://saml.example.com/entityid
+* Metadata Source: https://mocksaml.com/api/saml/metadata
+* Backend Name: tpa-saml
+* Site: localhost:18000
+* SAML Configuration: → SAMLConfiguration (id=6)
+* Enabled: True
+
+**SAMLProviderData** (id=3):
+
+* Entity ID: https://saml.example.com/entityid
+* SSO URL: https://mocksaml.com/api/saml/sso
+* Public Key: (certificate from MockSAML metadata)
+* Fetched At: 2026-02-27 18:05:40+00:00
+* Expires At: 2027-02-27 18:05:41+00:00
+* Valid: True
+
+**MockSAML Configuration**:
+
+* SP Entity ID: https://saml.example.com/entityid
+* ACS URL: http://localhost:18000/auth/complete/tpa-saml/
+* Test User Attributes: email, firstName, lastName, uid
diff --git a/common/djangoapps/third_party_auth/management/commands/saml.py b/common/djangoapps/third_party_auth/management/commands/saml.py
index afe369c2ade0..6865ebf69987 100644
--- a/common/djangoapps/third_party_auth/management/commands/saml.py
+++ b/common/djangoapps/third_party_auth/management/commands/saml.py
@@ -6,7 +6,6 @@
import logging
from django.core.management.base import BaseCommand, CommandError
-from edx_django_utils.monitoring import set_custom_attribute
from common.djangoapps.third_party_auth.tasks import fetch_saml_metadata
from common.djangoapps.third_party_auth.models import SAMLProviderConfig, SAMLConfiguration
@@ -71,31 +70,28 @@ def _handle_run_checks(self):
"""
Handle the --run-checks option for checking SAMLProviderConfig configuration issues.
- This is a report-only command. It identifies potential configuration problems such as:
- - Outdated SAMLConfiguration references (provider pointing to old config version)
- - Site ID mismatches between SAMLProviderConfig and its SAMLConfiguration
- - Slug mismatches (except 'default' slugs) # noqa: E501
- - SAMLProviderConfig objects with null SAMLConfiguration references (informational)
-
- Includes observability attributes for monitoring.
+ This is a report-only command that identifies potential configuration problems.
"""
- # Set custom attributes for monitoring the check operation
- # .. custom_attribute_name: saml_management_command.operation
- # .. custom_attribute_description: Records current SAML operation ('run_checks').
- set_custom_attribute('saml_management_command.operation', 'run_checks')
-
metrics = self._check_provider_configurations()
self._report_check_summary(metrics)
def _check_provider_configurations(self):
"""
- Check each provider configuration for potential issues.
+ Check each provider configuration for potential issues:
+ - Outdated configuration references
+ - Site ID mismatches
+ - Missing configurations (no direct config and no default)
+ - Disabled providers and configurations
+ Also reports informational data such as slug mismatches.
+
+ See code comments near each log output for possible resolution details.
Returns a dictionary of metrics about the found issues.
"""
outdated_count = 0
site_mismatch_count = 0
slug_mismatch_count = 0
null_config_count = 0
+ disabled_config_count = 0
error_count = 0
total_providers = 0
@@ -107,53 +103,74 @@ def _check_provider_configurations(self):
for provider_config in provider_configs:
total_providers += 1
+
+ # Check if provider is disabled
+ provider_disabled = not provider_config.enabled
+ disabled_status = ", enabled=False" if provider_disabled else ""
+
provider_info = (
- f"Provider (id={provider_config.id}, name={provider_config.name}, "
- f"slug={provider_config.slug}, site_id={provider_config.site_id})"
+ f"Provider (id={provider_config.id}, "
+ f"name={provider_config.name}, slug={provider_config.slug}, "
+ f"site_id={provider_config.site_id}{disabled_status})"
)
- if not provider_config.saml_configuration:
- self.stdout.write(
- f"[INFO] {provider_info} has no SAML configuration because "
- "a matching default was not found."
- )
- null_config_count += 1
- continue
+ # Provider disabled status is already included in provider_info format
try:
+ if not provider_config.saml_configuration:
+ null_config_count, disabled_config_count = self._check_no_config(
+ provider_config, provider_info, null_config_count, disabled_config_count
+ )
+ continue
+
+ # Check if SAML configuration is disabled
+ if not provider_config.saml_configuration.enabled:
+ # Resolution: Enable the SAML configuration in Django admin
+ # or assign a different configuration
+ self.stdout.write(
+ f"[WARNING] {provider_info} "
+ f"has SAML config (id={provider_config.saml_configuration_id}, enabled=False)."
+ )
+ disabled_config_count += 1
+
+ # Check configuration currency
current_config = SAMLConfiguration.current(
provider_config.saml_configuration.site_id,
provider_config.saml_configuration.slug
)
- # Check for outdated configuration references
- if current_config:
- if current_config.id != provider_config.saml_configuration_id:
- self.stdout.write(
- f"[WARNING] {provider_info} "
- f"has outdated SAML config (id={provider_config.saml_configuration_id} which "
- f"should be updated to the current SAML config (id={current_config.id})."
- )
- outdated_count += 1
+ if current_config and (current_config.id != provider_config.saml_configuration_id):
+ # Resolution: Update the provider's saml_configuration_id to the current config ID
+ self.stdout.write(
+ f"[WARNING] {provider_info} "
+ f"has outdated SAML config (id={provider_config.saml_configuration_id}) which "
+ f"should be updated to the current SAML config (id={current_config.id})."
+ )
+ outdated_count += 1
+ # Check site ID match
if provider_config.saml_configuration.site_id != provider_config.site_id:
config_site_id = provider_config.saml_configuration.site_id
- provider_site_id = provider_config.site_id
+ # Resolution: Create a new SAML configuration for the correct site
+ # or move the provider to the matching site
self.stdout.write(
f"[WARNING] {provider_info} "
- f"SAML config (id={provider_config.saml_configuration_id}, site_id={config_site_id}) "
- "does not match the provider's site_id."
+ f"SAML config (id={provider_config.saml_configuration_id}, "
+ f"site_id={config_site_id}) does not match the provider's site_id."
)
site_mismatch_count += 1
- saml_configuration_slug = provider_config.saml_configuration.slug
- provider_config_slug = provider_config.slug
-
- if saml_configuration_slug not in (provider_config_slug, 'default'):
+ # Check slug match
+ if provider_config.saml_configuration.slug not in (provider_config.slug, 'default'):
+ config_id = provider_config.saml_configuration_id
+ saml_configuration_slug = provider_config.saml_configuration.slug
+ config_disabled_status = ", enabled=False" if not provider_config.saml_configuration.enabled else ""
+ # Resolution: This is informational only - provider can use
+ # a different slug configuration
self.stdout.write(
- f"[WARNING] {provider_info} "
- f"SAML config (id={provider_config.saml_configuration_id}, slug='{saml_configuration_slug}') "
- "does not match the provider's slug."
+ f"[INFO] {provider_info} has "
+ f"SAML config (id={config_id}, slug='{saml_configuration_slug}'{config_disabled_status}) "
+ "that does not match the provider's slug."
)
slug_mismatch_count += 1
@@ -165,41 +182,64 @@ def _check_provider_configurations(self):
'total_providers': {'count': total_providers, 'requires_attention': False},
'outdated_count': {'count': outdated_count, 'requires_attention': True},
'site_mismatch_count': {'count': site_mismatch_count, 'requires_attention': True},
- 'slug_mismatch_count': {'count': slug_mismatch_count, 'requires_attention': True},
+ 'slug_mismatch_count': {'count': slug_mismatch_count, 'requires_attention': False},
'null_config_count': {'count': null_config_count, 'requires_attention': False},
+ 'disabled_config_count': {'count': disabled_config_count, 'requires_attention': True},
'error_count': {'count': error_count, 'requires_attention': True},
}
- for key, metric_data in metrics.items():
- # .. custom_attribute_name: saml_management_command.{key}
- # .. custom_attribute_description: Records metrics from SAML configuration checks.
- set_custom_attribute(f'saml_management_command.{key}', metric_data['count'])
-
return metrics
+ def _check_no_config(self, provider_config, provider_info, null_config_count, disabled_config_count):
+ """Helper to check providers with no direct SAML configuration."""
+ default_config = SAMLConfiguration.current(provider_config.site_id, 'default')
+ if not default_config or default_config.id is None:
+ # Resolution: Create/Link a SAML configuration for this provider
+ # or create/link a default configuration for the site
+ self.stdout.write(
+ f"[WARNING] {provider_info} has no direct SAML configuration and "
+ "no matching default configuration was found."
+ )
+ null_config_count += 1
+
+ elif not default_config.enabled:
+ # Resolution: Enable the provider's linked SAML configuration
+ # or create/link a specific configuration for this provider
+ self.stdout.write(
+ f"[WARNING] {provider_info} has no direct SAML configuration and "
+ f"the default configuration (id={default_config.id}, enabled=False)."
+ )
+ disabled_config_count += 1
+
+ return null_config_count, disabled_config_count
+
def _report_check_summary(self, metrics):
"""
- Print a summary of the check results and set the total_requiring_attention custom attribute.
+ Print a summary of the check results.
"""
total_requiring_attention = sum(
metric_data['count'] for metric_data in metrics.values()
if metric_data['requires_attention']
)
- # .. custom_attribute_name: saml_management_command.total_requiring_attention
- # .. custom_attribute_description: The total number of configuration issues requiring attention.
- set_custom_attribute('saml_management_command.total_requiring_attention', total_requiring_attention)
-
self.stdout.write(self.style.SUCCESS("CHECK SUMMARY:"))
self.stdout.write(f" Providers checked: {metrics['total_providers']['count']}")
- self.stdout.write(f" Null configs: {metrics['null_config_count']['count']}")
+ self.stdout.write("")
+
+ # Informational only section
+ self.stdout.write("Informational only:")
+ self.stdout.write(f" Slug mismatches: {metrics['slug_mismatch_count']['count']}")
+ self.stdout.write(f" Missing configs: {metrics['null_config_count']['count']}")
+ self.stdout.write("")
+ # Issues requiring attention section
if total_requiring_attention > 0:
- self.stdout.write("\nIssues requiring attention:")
+ self.stdout.write("Issues requiring attention:")
self.stdout.write(f" Outdated: {metrics['outdated_count']['count']}")
self.stdout.write(f" Site mismatches: {metrics['site_mismatch_count']['count']}")
- self.stdout.write(f" Slug mismatches: {metrics['slug_mismatch_count']['count']}")
+ self.stdout.write(f" Disabled configs: {metrics['disabled_config_count']['count']}")
self.stdout.write(f" Errors: {metrics['error_count']['count']}")
- self.stdout.write(f"\nTotal issues requiring attention: {total_requiring_attention}")
+ self.stdout.write("")
+ self.stdout.write(f"Total issues requiring attention: {total_requiring_attention}")
else:
- self.stdout.write(self.style.SUCCESS("\nNo configuration issues found!"))
+ self.stdout.write(self.style.SUCCESS("No configuration issues found!"))
diff --git a/common/djangoapps/third_party_auth/management/commands/tests/test_saml.py b/common/djangoapps/third_party_auth/management/commands/tests/test_saml.py
index 6963d5dcd0d5..d80c9146664b 100644
--- a/common/djangoapps/third_party_auth/management/commands/tests/test_saml.py
+++ b/common/djangoapps/third_party_auth/management/commands/tests/test_saml.py
@@ -79,6 +79,7 @@ def setUp(self):
name='TestShib College',
entity_id='https://idp.testshib.org/idp/shibboleth',
metadata_source='https://www.testshib.org/metadata/testshib-providers.xml',
+ saml_configuration=self.saml_config,
)
def _setup_test_configs_for_run_checks(self):
@@ -337,8 +338,30 @@ def _run_checks_command(self):
call_command('saml', '--run-checks', stdout=out)
return out.getvalue()
- @mock.patch('common.djangoapps.third_party_auth.management.commands.saml.set_custom_attribute')
- def test_run_checks_outdated_configs(self, mock_set_custom_attribute):
+ def test_run_checks_setup_test_data(self):
+ """
+ Test the --run-checks command against initial setup test data.
+
+ This test validates that the base setup data (from setUp) is correctly
+ identified as having configuration issues. The setup includes a provider
+ (self.provider_config) with a disabled SAML configuration (self.saml_config),
+ which is reported as a disabled config issue (not a missing config).
+ """
+ output = self._run_checks_command()
+
+ # The setup data includes a provider with a disabled SAML config
+ expected_warning = (
+ f'[WARNING] Provider (id={self.provider_config.id}, '
+ f'name={self.provider_config.name}, '
+ f'slug={self.provider_config.slug}, '
+ f'site_id={self.provider_config.site_id}) '
+ f'has SAML config (id={self.saml_config.id}, enabled=False).'
+ )
+ self.assertIn(expected_warning, output)
+ self.assertIn('Missing configs: 0', output) # No missing configs from setUp
+ self.assertIn('Disabled configs: 1', output) # From setUp: provider_config with disabled saml_config
+
+ def test_run_checks_outdated_configs(self):
"""
Test the --run-checks command identifies outdated configurations.
"""
@@ -346,31 +369,18 @@ def test_run_checks_outdated_configs(self, mock_set_custom_attribute):
output = self._run_checks_command()
- self.assertIn('[WARNING]', output)
- self.assertIn('test-provider', output)
- self.assertIn(
- f'id={old_config.id} which should be updated to the current SAML config (id={new_config.id})',
- output
+ expected_warning = (
+ f'[WARNING] Provider (id={test_provider_config.id}, name={test_provider_config.name}, '
+ f'slug={test_provider_config.slug}, site_id={test_provider_config.site_id}) '
+ f'has outdated SAML config (id={old_config.id}) which should be updated to '
+ f'the current SAML config (id={new_config.id}).'
)
- self.assertIn('CHECK SUMMARY:', output)
- self.assertIn('Providers checked: 2', output)
+ self.assertIn(expected_warning, output)
self.assertIn('Outdated: 1', output)
+ # Total includes: 1 outdated + 2 disabled configs (setUp + test's old_config which is also disabled)
+ self.assertIn('Total issues requiring attention: 3', output)
- # Check key observability calls
- expected_calls = [
- mock.call('saml_management_command.operation', 'run_checks'),
- mock.call('saml_management_command.total_providers', 2),
- mock.call('saml_management_command.outdated_count', 1),
- mock.call('saml_management_command.site_mismatch_count', 0),
- mock.call('saml_management_command.slug_mismatch_count', 1),
- mock.call('saml_management_command.null_config_count', 1),
- mock.call('saml_management_command.error_count', 0),
- mock.call('saml_management_command.total_requiring_attention', 2),
- ]
- mock_set_custom_attribute.assert_has_calls(expected_calls, any_order=False)
-
- @mock.patch('common.djangoapps.third_party_auth.management.commands.saml.set_custom_attribute')
- def test_run_checks_site_mismatches(self, mock_set_custom_attribute):
+ def test_run_checks_site_mismatches(self):
"""
Test the --run-checks command identifies site ID mismatches.
"""
@@ -380,7 +390,7 @@ def test_run_checks_site_mismatches(self, mock_set_custom_attribute):
entity_id='https://example.com'
)
- SAMLProviderConfigFactory.create(
+ provider = SAMLProviderConfigFactory.create(
site=self.site,
slug='test-provider',
saml_configuration=config
@@ -388,25 +398,17 @@ def test_run_checks_site_mismatches(self, mock_set_custom_attribute):
output = self._run_checks_command()
- self.assertIn('[WARNING]', output)
- self.assertIn('test-provider', output)
- self.assertIn('does not match the provider\'s site_id', output)
-
- # Check observability calls
- expected_calls = [
- mock.call('saml_management_command.operation', 'run_checks'),
- mock.call('saml_management_command.total_providers', 2),
- mock.call('saml_management_command.outdated_count', 0),
- mock.call('saml_management_command.site_mismatch_count', 1),
- mock.call('saml_management_command.slug_mismatch_count', 1),
- mock.call('saml_management_command.null_config_count', 1),
- mock.call('saml_management_command.error_count', 0),
- mock.call('saml_management_command.total_requiring_attention', 2),
- ]
- mock_set_custom_attribute.assert_has_calls(expected_calls, any_order=False)
-
- @mock.patch('common.djangoapps.third_party_auth.management.commands.saml.set_custom_attribute')
- def test_run_checks_slug_mismatches(self, mock_set_custom_attribute):
+ expected_warning = (
+ f'[WARNING] Provider (id={provider.id}, name={provider.name}, '
+ f'slug={provider.slug}, site_id={provider.site_id}) '
+ f'SAML config (id={config.id}, site_id={config.site_id}) does not match the provider\'s site_id.'
+ )
+ self.assertIn(expected_warning, output)
+ self.assertIn('Site mismatches: 1', output)
+ # Total includes: 1 site mismatch + 1 disabled config (from setUp)
+ self.assertIn('Total issues requiring attention: 2', output)
+
+ def test_run_checks_slug_mismatches(self):
"""
Test the --run-checks command identifies slug mismatches.
"""
@@ -416,7 +418,7 @@ def test_run_checks_slug_mismatches(self, mock_set_custom_attribute):
entity_id='https://example.com'
)
- SAMLProviderConfigFactory.create(
+ provider = SAMLProviderConfigFactory.create(
site=self.site,
slug='provider-slug',
saml_configuration=config
@@ -424,29 +426,23 @@ def test_run_checks_slug_mismatches(self, mock_set_custom_attribute):
output = self._run_checks_command()
- self.assertIn('[WARNING]', output)
- self.assertIn('provider-slug', output)
- self.assertIn('does not match the provider\'s slug', output)
-
- # Check observability calls
- expected_calls = [
- mock.call('saml_management_command.operation', 'run_checks'),
- mock.call('saml_management_command.total_providers', 2),
- mock.call('saml_management_command.outdated_count', 0),
- mock.call('saml_management_command.site_mismatch_count', 0),
- mock.call('saml_management_command.slug_mismatch_count', 1),
- mock.call('saml_management_command.null_config_count', 1),
- mock.call('saml_management_command.error_count', 0),
- mock.call('saml_management_command.total_requiring_attention', 1),
- ]
- mock_set_custom_attribute.assert_has_calls(expected_calls, any_order=False)
-
- @mock.patch('common.djangoapps.third_party_auth.management.commands.saml.set_custom_attribute')
- def test_run_checks_null_configurations(self, mock_set_custom_attribute):
+ expected_info = (
+ f'[INFO] Provider (id={provider.id}, name={provider.name}, '
+ f'slug={provider.slug}, site_id={provider.site_id}) '
+ f'has SAML config (id={config.id}, slug=\'{config.slug}\') '
+ f'that does not match the provider\'s slug.'
+ )
+ self.assertIn(expected_info, output)
+ self.assertIn('Slug mismatches: 1', output)
+
+ def test_run_checks_null_configurations(self):
"""
Test the --run-checks command identifies providers with null configurations.
+ This test verifies that providers with no direct SAML configuration and no
+ default configuration available are properly reported.
"""
- SAMLProviderConfigFactory.create(
+ # Create a provider with no SAML configuration on a site that has no default config
+ provider = SAMLProviderConfigFactory.create(
site=self.site,
slug='null-provider',
saml_configuration=None
@@ -454,19 +450,101 @@ def test_run_checks_null_configurations(self, mock_set_custom_attribute):
output = self._run_checks_command()
- self.assertIn('[INFO]', output)
- self.assertIn('null-provider', output)
- self.assertIn('has no SAML configuration because a matching default was not found', output)
-
- # Check observability calls
- expected_calls = [
- mock.call('saml_management_command.operation', 'run_checks'),
- mock.call('saml_management_command.total_providers', 2),
- mock.call('saml_management_command.outdated_count', 0),
- mock.call('saml_management_command.site_mismatch_count', 0),
- mock.call('saml_management_command.slug_mismatch_count', 0),
- mock.call('saml_management_command.null_config_count', 2),
- mock.call('saml_management_command.error_count', 0),
- mock.call('saml_management_command.total_requiring_attention', 0),
- ]
- mock_set_custom_attribute.assert_has_calls(expected_calls, any_order=False)
+ expected_warning = (
+ f'[WARNING] Provider (id={provider.id}, name={provider.name}, '
+ f'slug={provider.slug}, site_id={provider.site_id}) '
+ f'has no direct SAML configuration and no matching default configuration was found.'
+ )
+ self.assertIn(expected_warning, output)
+ self.assertIn('Missing configs: 1', output) # 1 from this test (provider with no config and no default)
+ self.assertIn('Disabled configs: 1', output) # 1 from setUp data
+
+ def test_run_checks_null_config_id(self):
+ """
+ Test the --run-checks command identifies providers with disabled default configurations.
+ When a provider has no direct SAML configuration and the default config is disabled,
+ it should be reported as a disabled config issue (not a missing config).
+ """
+ # Create a disabled default configuration for this site
+ disabled_default_config = SAMLConfigurationFactory.create(
+ site=self.site,
+ slug='default',
+ entity_id='https://default.example.com',
+ enabled=False
+ )
+
+ # Create a provider with no direct SAML configuration
+ # It will fall back to the disabled default config
+ provider = SAMLProviderConfigFactory.create(
+ site=self.site,
+ slug='null-id-provider',
+ saml_configuration=None
+ )
+
+ output = self._run_checks_command()
+
+ expected_warning = (
+ f'[WARNING] Provider (id={provider.id}, name={provider.name}, '
+ f'slug={provider.slug}, site_id={provider.site_id}) '
+ f'has no direct SAML configuration and the default configuration '
+ f'(id={disabled_default_config.id}, enabled=False).'
+ )
+ self.assertIn(expected_warning, output)
+ self.assertIn('Missing configs: 0', output) # No missing configs since default config exists
+ self.assertIn('Disabled configs: 2', output) # 1 from this test + 1 from setUp data
+
+ def test_run_checks_with_default_config(self):
+ """
+ Test the --run-checks command correctly handles providers with default configurations.
+ """
+ provider = SAMLProviderConfigFactory.create(
+ site=self.site,
+ slug='default-config-provider',
+ saml_configuration=None
+ )
+
+ default_config = SAMLConfigurationFactory.create(
+ site=self.site,
+ slug='default',
+ entity_id='https://default.example.com'
+ )
+
+ output = self._run_checks_command()
+
+ self.assertIn('Missing configs: 0', output) # This tests provider has valid default config
+ self.assertIn('Disabled configs: 1', output) # From setUp
+
+ def test_run_checks_disabled_functionality(self):
+ """
+ Test the --run-checks command handles disabled providers and configurations.
+ """
+ disabled_provider = SAMLProviderConfigFactory.create(
+ site=self.site,
+ slug='disabled-provider',
+ enabled=False
+ )
+
+ disabled_config = SAMLConfigurationFactory.create(
+ site=self.site,
+ slug='disabled-config',
+ enabled=False
+ )
+
+ provider_with_disabled_config = SAMLProviderConfigFactory.create(
+ site=self.site,
+ slug='provider-with-disabled-config',
+ saml_configuration=disabled_config
+ )
+
+ output = self._run_checks_command()
+
+ expected_warning = (
+ f'[WARNING] Provider (id={provider_with_disabled_config.id}, '
+ f'name={provider_with_disabled_config.name}, '
+ f'slug={provider_with_disabled_config.slug}, '
+ f'site_id={provider_with_disabled_config.site_id}) '
+ f'has SAML config (id={disabled_config.id}, enabled=False).'
+ )
+ self.assertIn(expected_warning, output)
+ self.assertIn('Missing configs: 1', output) # disabled_provider has no config
+ self.assertIn('Disabled configs: 2', output) # setUp's provider + provider_with_disabled_config
diff --git a/common/djangoapps/third_party_auth/migrations/0014_samlproviderconfig_optional_email_checkboxes.py b/common/djangoapps/third_party_auth/migrations/0014_samlproviderconfig_optional_email_checkboxes.py
new file mode 100644
index 000000000000..34fcf3c97b58
--- /dev/null
+++ b/common/djangoapps/third_party_auth/migrations/0014_samlproviderconfig_optional_email_checkboxes.py
@@ -0,0 +1,26 @@
+# Generated migration for adding optional checkbox skip configuration field
+
+from django.db import migrations, models
+import django.utils.translation
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('third_party_auth', '0013_default_site_id_wrapper_function'),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name='samlproviderconfig',
+ name='skip_registration_optional_checkboxes',
+ field=models.BooleanField(
+ default=False,
+ help_text=django.utils.translation.gettext_lazy(
+ "If enabled, optional checkboxes (marketing emails opt-in, etc.) will not be rendered "
+ "on the registration form for users registering via this provider. When these checkboxes "
+ "are skipped, their values are inferred as False (opted out)."
+ ),
+ ),
+ ),
+ ]
diff --git a/common/djangoapps/third_party_auth/models.py b/common/djangoapps/third_party_auth/models.py
index 6d244d96eddd..6ae816674b07 100644
--- a/common/djangoapps/third_party_auth/models.py
+++ b/common/djangoapps/third_party_auth/models.py
@@ -745,6 +745,14 @@ class SAMLProviderConfig(ProviderConfig):
"immediately after authenticating with the third party instead of the login page."
),
)
+ skip_registration_optional_checkboxes = models.BooleanField(
+ default=False,
+ help_text=_(
+ "If enabled, optional checkboxes (marketing emails opt-in, etc.) will not be rendered "
+ "on the registration form for users registering via this provider. When these checkboxes "
+ "are skipped, their values are inferred as False (opted out)."
+ ),
+ )
other_settings = models.TextField(
verbose_name="Advanced settings", blank=True,
help_text=(
@@ -803,13 +811,27 @@ def get_url_params(self):
def is_active_for_pipeline(self, pipeline):
""" Is this provider being used for the specified pipeline? """
- return self.backend_name == pipeline['backend'] and self.slug == pipeline['kwargs']['response']['idp_name']
+ try:
+ return self.backend_name == pipeline['backend'] and self.slug == pipeline['kwargs']['response']['idp_name']
+ except KeyError:
+ return False
def match_social_auth(self, social_auth):
""" Is this provider being used for this UserSocialAuth entry? """
prefix = self.slug + ":"
return self.backend_name == social_auth.provider and social_auth.uid.startswith(prefix)
+ def get_remote_id_from_field_name(self, social_auth, field_name):
+ """ Given a UserSocialAuth object, return the user remote ID against the field name provided. """
+ if not self.match_social_auth(social_auth):
+ raise ValueError(
+ f"UserSocialAuth record does not match given provider {self.provider_id}"
+ )
+ field_value = social_auth.extra_data.get(field_name, None)
+ if field_value and isinstance(field_value, list):
+ return field_value[0]
+ return field_value
+
def get_remote_id_from_social_auth(self, social_auth):
""" Given a UserSocialAuth object, return the remote ID used by this provider. """
assert self.match_social_auth(social_auth)
diff --git a/common/djangoapps/third_party_auth/pipeline.py b/common/djangoapps/third_party_auth/pipeline.py
index 496cfce93c1f..d620d27f5357 100644
--- a/common/djangoapps/third_party_auth/pipeline.py
+++ b/common/djangoapps/third_party_auth/pipeline.py
@@ -62,6 +62,7 @@ def B(*args, **kwargs):
import hashlib
import hmac
import json
+import urllib.parse
from collections import OrderedDict
from logging import getLogger
from smtplib import SMTPException
@@ -101,6 +102,10 @@ def B(*args, **kwargs):
is_saml_provider,
user_exists,
)
+from common.djangoapps.third_party_auth.toggles import (
+ is_saml_provider_site_fallback_enabled,
+ is_tpa_next_url_on_dispatch_enabled,
+)
from common.djangoapps.track import segment
from common.djangoapps.util.json_request import JsonResponse
@@ -358,7 +363,11 @@ def get_complete_url(backend_name):
ValueError: if no provider is enabled with the given backend_name.
"""
if not any(provider.Registry.get_enabled_by_backend_name(backend_name)):
- raise ValueError('Provider with backend %s not enabled' % backend_name)
+ # When the SAML site-fallback flag is on, the provider may not be visible to the
+ # site-filtered registry even though SAML auth already completed via a
+ # site-independent lookup. Allow get_complete_url to proceed in that case.
+ if not (is_saml_provider_site_fallback_enabled() and backend_name == 'tpa-saml'):
+ raise ValueError('Provider with backend %s not enabled' % backend_name)
return _get_url('social:complete', backend_name)
@@ -576,13 +585,23 @@ def ensure_user_information(strategy, auth_entry, backend=None, user=None, socia
# It is important that we always execute the entire pipeline. Even if
# behavior appears correct without executing a step, it means important
# invariants have been violated and future misbehavior is likely.
+ def _build_redirect_url(base_url):
+ """Append ?next=… to the redirect URL if the session carries a destination."""
+ if not is_tpa_next_url_on_dispatch_enabled():
+ return base_url
+ next_url = strategy.session_get('next')
+ if next_url and isinstance(next_url, str):
+ separator = '&' if '?' in base_url else '?'
+ base_url = f'{base_url}{separator}next={urllib.parse.quote(next_url)}'
+ return base_url
+
def dispatch_to_login():
"""Redirects to the login page."""
- return redirect(AUTH_DISPATCH_URLS[AUTH_ENTRY_LOGIN])
+ return redirect(_build_redirect_url(AUTH_DISPATCH_URLS[AUTH_ENTRY_LOGIN]))
def dispatch_to_register():
"""Redirects to the registration page."""
- return redirect(AUTH_DISPATCH_URLS[AUTH_ENTRY_REGISTER])
+ return redirect(_build_redirect_url(AUTH_DISPATCH_URLS[AUTH_ENTRY_REGISTER]))
def should_force_account_creation():
""" For some third party providers, we auto-create user accounts """
@@ -603,13 +622,35 @@ def is_provider_saml():
if not user:
# Use only email for user existence check in case of saml provider
- if is_provider_saml():
+ _is_saml = is_provider_saml()
+ _provider_obj = provider.Registry.get_from_pipeline({'backend': current_partial.backend, 'kwargs': kwargs})
+ logger.info(
+ '[THIRD_PARTY_AUTH] ensure_user_information: auth_entry=%s backend=%s is_provider_saml=%s '
+ 'current_provider=%s skip_email_verification=%s send_to_registration_first=%s '
+ 'email=%s kwargs_response_keys=%s',
+ auth_entry,
+ current_partial.backend,
+ _is_saml,
+ _provider_obj.provider_id if _provider_obj else None,
+ _provider_obj.skip_email_verification if _provider_obj else None,
+ _provider_obj.send_to_registration_first if _provider_obj else None,
+ details.get('email') if details else None,
+ list((kwargs.get('response') or {}).keys()),
+ )
+ if _is_saml:
user_details = {'email': details.get('email')} if details else None
else:
user_details = details
- if user_exists(user_details or {}):
+ _user_exists = user_exists(user_details or {})
+ logger.info(
+ '[THIRD_PARTY_AUTH] ensure_user_information: user_exists=%s user_details_email=%s',
+ _user_exists,
+ (user_details or {}).get('email'),
+ )
+ if _user_exists:
# User has not already authenticated and the details sent over from
# identity provider belong to an existing user.
+ logger.info('[THIRD_PARTY_AUTH] ensure_user_information: dispatching to login (user exists)')
return dispatch_to_login()
if is_api(auth_entry):
@@ -617,8 +658,14 @@ def is_provider_saml():
elif auth_entry == AUTH_ENTRY_LOGIN:
# User has authenticated with the third party provider but we don't know which edX
# account corresponds to them yet, if any.
- if should_force_account_creation():
+ _force = should_force_account_creation()
+ logger.info(
+ '[THIRD_PARTY_AUTH] ensure_user_information: AUTH_ENTRY_LOGIN should_force_account_creation=%s',
+ _force,
+ )
+ if _force:
return dispatch_to_register()
+ logger.info('[THIRD_PARTY_AUTH] ensure_user_information: dispatching to login (no force create)')
return dispatch_to_login()
elif auth_entry == AUTH_ENTRY_REGISTER:
# User has authenticated with the third party provider and now wants to finish
@@ -1009,7 +1056,7 @@ def get_username(strategy, details, backend, user=None, *args, **kwargs): # lin
else:
slug_func = lambda val: val
- if is_auto_generated_username_enabled():
+ if is_auto_generated_username_enabled() and details.get('username') is None:
username = get_auto_generated_username(details)
else:
if email_as_username and details.get('email'):
diff --git a/common/djangoapps/third_party_auth/provider.py b/common/djangoapps/third_party_auth/provider.py
index 5eaf0f1888de..3cb9f8642158 100644
--- a/common/djangoapps/third_party_auth/provider.py
+++ b/common/djangoapps/third_party_auth/provider.py
@@ -16,6 +16,7 @@
SAMLConfiguration,
SAMLProviderConfig
)
+from common.djangoapps.third_party_auth.toggles import is_saml_provider_site_fallback_enabled
class Registry:
@@ -97,6 +98,19 @@ def get_from_pipeline(cls, running_pipeline):
if enabled.is_active_for_pipeline(running_pipeline):
return enabled
+ # Fallback for SAML: SAMLAuthBackend.get_idp() uses SAMLProviderConfig.current()
+ # which has no site check. If the provider's site_id doesn't match the current
+ # site (or SAMLConfiguration isn't enabled for the current site), _enabled_providers()
+ # won't yield it — but the SAML handshake already completed. Look up the provider
+ # directly by idp_name so that pipeline steps like should_force_account_creation()
+ # can still read provider flags.
+ if is_saml_provider_site_fallback_enabled() and running_pipeline.get('backend') == 'tpa-saml':
+ try:
+ idp_name = running_pipeline['kwargs']['response']['idp_name']
+ return SAMLProviderConfig.current(idp_name)
+ except (KeyError, SAMLProviderConfig.DoesNotExist):
+ pass
+
@classmethod
def get_enabled_by_backend_name(cls, backend_name):
"""Generator returning all enabled providers that use the specified
diff --git a/common/djangoapps/third_party_auth/saml.py b/common/djangoapps/third_party_auth/saml.py
index 8e78f9e36fc9..3f1cdf30a9e1 100644
--- a/common/djangoapps/third_party_auth/saml.py
+++ b/common/djangoapps/third_party_auth/saml.py
@@ -4,6 +4,7 @@
import logging
+from urllib.parse import unquote
from copy import deepcopy
import requests
@@ -17,6 +18,7 @@
from social_core.exceptions import AuthForbidden, AuthMissingParameter
from openedx.core.djangoapps.theming.helpers import get_current_request
+from openedx.core.djangoapps.user_authn.utils import is_safe_login_or_logout_redirect
from common.djangoapps.third_party_auth.exceptions import IncorrectConfigurationException
STANDARD_SAML_PROVIDER_KEY = 'standard_saml_provider'
@@ -89,12 +91,71 @@ def auth_complete(self, *args, **kwargs):
"""
Handle exceptions that happen during SAML authentication
"""
+ # For IdP-initiated flows (where the user doesn't first hit /auth/login/...),
+ # allow callers to provide a post-auth redirect by packing it into RelayState.
+ # Store it in the session so the rest of the pipeline behaves consistently.
+ try:
+ request = get_current_request()
+ # Allow RelayState to carry both IdP slug and a post-auth destination.
+ # Format: "|", where is typically a relative LMS path.
+ self._maybe_set_next_url_from_relay_state(request)
+ except Exception: # pylint: disable=broad-exception-caught # pragma: no cover
+ # Never fail auth due to redirect bookkeeping.
+ pass
+
try:
return super().auth_complete(*args, **kwargs)
# We are seeing errors of MultiValueDictKeyError looking for the parameter 'RelayState'.
# We would like to have a more specific error to handle for observability purposes.
except MultiValueDictKeyError as e:
- raise AuthMissingParameter(self.name, e.args[0]) from e
+ raise AuthMissingParameter(self.name, e.args[0] if e.args else '') from e
+
+ @staticmethod
+ def _maybe_set_next_url_from_relay_state(request):
+ """Optionally extract a safe `next` from RelayState and rewrite RelayState to the IdP slug.
+
+ This is specifically to support IdP-initiated flows where Auth0 (and some IdPs) can only
+ reliably influence the SAML POST via RelayState.
+ """
+ if request is None or not hasattr(request, 'POST'):
+ return
+ if not hasattr(request, 'session'):
+ return
+
+ relay_state = None
+ try:
+ relay_state = request.POST.get('RelayState')
+ except Exception: # pylint: disable=broad-exception-caught # pragma: no cover
+ relay_state = None
+
+ if not relay_state or '|' not in str(relay_state):
+ return
+
+ slug_part, next_part = str(relay_state).split('|', 1)
+ slug_part = slug_part.strip()
+ next_part = next_part.strip()
+ if not slug_part or not next_part:
+ return
+
+ # URL-decode next (Auth0 or callers may URL-encode it).
+ next_decoded = unquote(next_part)
+
+ # Only store next if it's safe per existing Open edX redirect policy.
+ if is_safe_login_or_logout_redirect(
+ redirect_to=next_decoded,
+ request_host=request.get_host(),
+ dot_client_id=(request.GET.get('client_id') if hasattr(request, 'GET') else None),
+ require_https=request.is_secure(),
+ ):
+ request.session['next'] = next_decoded
+ else:
+ # RelayState included an unsafe destination; clear any stale 'next' value
+ request.session.pop('next', None)
+
+ # Always rewrite RelayState to just the IdP slug so the SAML backend can locate the provider.
+ post_copy = request.POST.copy()
+ post_copy['RelayState'] = slug_part
+ request._post = post_copy # pylint: disable=protected-access
def get_user_id(self, details, response):
"""
diff --git a/common/djangoapps/third_party_auth/tests/test_pipeline_integration.py b/common/djangoapps/third_party_auth/tests/test_pipeline_integration.py
index 0d0a2cf7241d..13c6749b0851 100644
--- a/common/djangoapps/third_party_auth/tests/test_pipeline_integration.py
+++ b/common/djangoapps/third_party_auth/tests/test_pipeline_integration.py
@@ -379,6 +379,92 @@ def test_redirect_for_saml_based_on_email_only(self, email, expected_redirect_ur
assert response.url == expected_redirect_url
+@ddt.ddt
+class EnsureUserInformationNextUrlTestCase(test.TestCase):
+ """Tests that ensure_user_information forwards session['next'] as a query parameter."""
+
+ def _call_ensure_user_information(self, session_next, auth_entry=pipeline.AUTH_ENTRY_LOGIN,
+ send_to_registration_first=True):
+ """Helper to call ensure_user_information with a controlled session_get('next') value."""
+ mock_provider = mock.MagicMock(
+ send_to_registration_first=send_to_registration_first,
+ skip_email_verification=False,
+ )
+ with mock.patch(
+ 'common.djangoapps.third_party_auth.pipeline.provider.Registry.get_from_pipeline'
+ ) as get_from_pipeline:
+ get_from_pipeline.return_value = mock_provider
+ with mock.patch('social_core.pipeline.partial.partial_prepare') as partial_prepare:
+ partial_prepare.return_value = mock.MagicMock(token='')
+ strategy = mock.MagicMock()
+ strategy.session_get.side_effect = lambda key, *args: (
+ session_next if key == 'next' else mock.DEFAULT
+ )
+ response = pipeline.ensure_user_information(
+ strategy=strategy,
+ backend=None,
+ auth_entry=auth_entry,
+ pipeline_index=0,
+ )
+ return response
+
+ @mock.patch(
+ 'common.djangoapps.third_party_auth.pipeline.is_tpa_next_url_on_dispatch_enabled',
+ return_value=True,
+ )
+ @ddt.data(
+ # (session_next, send_to_registration_first, expected_url)
+ ('/courses/my-course', True, '/register?next=/courses/my-course'),
+ ('/courses/my-course', False, '/login?next=/courses/my-course'),
+ ('/dashboard', True, '/register?next=/dashboard'),
+ )
+ @ddt.unpack
+ def test_next_url_forwarded_to_redirect(self, session_next, send_to_registration_first, expected_url, _flag_mock):
+ """When session contains a 'next' URL, it should be appended as a query parameter."""
+ response = self._call_ensure_user_information(
+ session_next=session_next,
+ send_to_registration_first=send_to_registration_first,
+ )
+ assert response.status_code == 302
+ assert response.url == expected_url
+
+ @mock.patch(
+ 'common.djangoapps.third_party_auth.pipeline.is_tpa_next_url_on_dispatch_enabled',
+ return_value=True,
+ )
+ @ddt.data(None, '')
+ def test_no_next_url_gives_bare_redirect(self, session_next, _flag_mock):
+ """When session has no 'next' URL, the redirect should be bare /register."""
+ response = self._call_ensure_user_information(session_next=session_next)
+ assert response.status_code == 302
+ assert response.url == '/register'
+
+ @mock.patch(
+ 'common.djangoapps.third_party_auth.pipeline.is_tpa_next_url_on_dispatch_enabled',
+ return_value=True,
+ )
+ def test_next_url_with_special_characters_is_encoded(self, _flag_mock):
+ """Special characters in the next URL should be percent-encoded."""
+ response = self._call_ensure_user_information(
+ session_next='/courses/my course?foo=bar&baz=1',
+ )
+ assert response.status_code == 302
+ assert response.url.startswith('/register?next=')
+ # The space and & should be encoded
+ assert '%20' in response.url or '+' in response.url
+ assert 'foo%3Dbar' in response.url or 'foo=bar' in response.url
+
+ @mock.patch(
+ 'common.djangoapps.third_party_auth.pipeline.is_tpa_next_url_on_dispatch_enabled',
+ return_value=False,
+ )
+ def test_flag_disabled_gives_bare_redirect(self, _flag_mock):
+ """When the waffle flag is disabled, the redirect should be bare even with session['next']."""
+ response = self._call_ensure_user_information(session_next='/courses/my-course')
+ assert response.status_code == 302
+ assert response.url == '/register'
+
+
class UserDetailsForceSyncTestCase(TestCase):
"""Tests to ensure learner profile data is properly synced if the provider requires it."""
diff --git a/common/djangoapps/third_party_auth/tests/test_saml.py b/common/djangoapps/third_party_auth/tests/test_saml.py
index 6b966a3e6ea4..3a34a4dabd26 100644
--- a/common/djangoapps/third_party_auth/tests/test_saml.py
+++ b/common/djangoapps/third_party_auth/tests/test_saml.py
@@ -5,7 +5,9 @@
from unittest import mock
+from django.test import RequestFactory
from django.utils.datastructures import MultiValueDictKeyError
+from django.contrib.sessions.middleware import SessionMiddleware
from social_core.exceptions import AuthMissingParameter
from common.djangoapps.third_party_auth.saml import EdXSAMLIdentityProvider, get_saml_idp_class, SAMLAuthBackend
@@ -40,6 +42,14 @@ def test_get_user_details(self):
class TestSAMLAuthBackend(SAMLTestCase):
""" Tests for the SAML backend. """
+ @staticmethod
+ def _add_session(request):
+ """Attach a Django session to a RequestFactory request."""
+ middleware = SessionMiddleware(lambda req: None)
+ middleware.process_request(request)
+ request.session.save()
+ return request
+
@mock.patch('common.djangoapps.third_party_auth.saml.SAMLAuth.auth_complete')
def test_saml_auth_complete(self, super_auth_complete):
super_auth_complete.side_effect = MultiValueDictKeyError('RelayState')
@@ -48,3 +58,49 @@ def test_saml_auth_complete(self, super_auth_complete):
backend.auth_complete()
assert cm.exception.parameter == 'RelayState'
+
+ @mock.patch('common.djangoapps.third_party_auth.saml.get_current_request')
+ @mock.patch('common.djangoapps.third_party_auth.saml.SAMLAuth.auth_complete')
+ def test_relaystate_splits_and_sets_next_when_safe(self, super_auth_complete, get_current_request_mock):
+ """RelayState may include both the IdP slug and a safe `next` destination."""
+ rf = RequestFactory()
+ request = rf.post(
+ '/auth/complete/tpa-saml/',
+ data={
+ 'SAMLResponse': 'ignored',
+ 'RelayState': 'example-idp|/courses/course-v1:edX+DemoX+Demo_Course/course/',
+ },
+ HTTP_HOST=self.hostname,
+ )
+ self._add_session(request)
+ get_current_request_mock.return_value = request
+
+ super_auth_complete.return_value = 'ok'
+ backend = SAMLAuthBackend()
+ assert backend.auth_complete() == 'ok'
+
+ assert request.POST.get('RelayState') == 'example-idp'
+ assert request.session.get('next') == '/courses/course-v1:edX+DemoX+Demo_Course/course/'
+
+ @mock.patch('common.djangoapps.third_party_auth.saml.get_current_request')
+ @mock.patch('common.djangoapps.third_party_auth.saml.SAMLAuth.auth_complete')
+ def test_relaystate_drops_unsafe_next(self, super_auth_complete, get_current_request_mock):
+ """If RelayState contains an unsafe `next`, it is ignored but the slug is preserved."""
+ rf = RequestFactory()
+ request = rf.post(
+ '/auth/complete/tpa-saml/',
+ data={
+ 'SAMLResponse': 'ignored',
+ 'RelayState': 'example-idp|https%3A%2F%2Fevil.example.com%2Fpwn',
+ },
+ HTTP_HOST=self.hostname,
+ )
+ self._add_session(request)
+ get_current_request_mock.return_value = request
+
+ super_auth_complete.return_value = 'ok'
+ backend = SAMLAuthBackend()
+ assert backend.auth_complete() == 'ok'
+
+ assert request.POST.get('RelayState') == 'example-idp'
+ assert request.session.get('next') is None
diff --git a/common/djangoapps/third_party_auth/toggles.py b/common/djangoapps/third_party_auth/toggles.py
index d8f77f0b1cf2..51ba5fd0cdb5 100644
--- a/common/djangoapps/third_party_auth/toggles.py
+++ b/common/djangoapps/third_party_auth/toggles.py
@@ -43,3 +43,55 @@ def is_apple_user_migration_enabled():
Returns a boolean if Apple users migration is in process.
"""
return APPLE_USER_MIGRATION_FLAG.is_enabled()
+
+
+# .. toggle_name: third_party_auth.tpa_next_url_on_dispatch
+# .. toggle_implementation: WaffleFlag
+# .. toggle_default: False
+# .. toggle_description: When enabled, the third-party auth pipeline will forward
+# session['next'] as a ?next= query parameter when redirecting to the login or
+# registration page. This ensures the post-auth destination is preserved for new
+# users who must complete registration before being redirected.
+# .. toggle_use_cases: temporary
+# .. toggle_creation_date: 2026-02-13
+# .. toggle_target_removal_date: 2026-06-01
+# .. toggle_warning: None.
+TPA_NEXT_URL_ON_DISPATCH_FLAG = WaffleFlag(f'{THIRD_PARTY_AUTH_NAMESPACE}.tpa_next_url_on_dispatch', __name__)
+
+
+def is_tpa_next_url_on_dispatch_enabled():
+ """
+ Returns True if the pipeline should forward session['next'] as a query parameter
+ when dispatching to login/register pages.
+ """
+ return TPA_NEXT_URL_ON_DISPATCH_FLAG.is_enabled()
+
+
+# .. toggle_name: third_party_auth.saml_provider_site_fallback
+# .. toggle_implementation: WaffleFlag
+# .. toggle_default: False
+# .. toggle_description: When enabled, Registry.get_from_pipeline() will fall back to a
+# site-independent SAMLProviderConfig lookup when the site-filtered registry returns no
+# match for a running SAML pipeline. This handles cases where the SAMLProviderConfig or
+# SAMLConfiguration is associated with a different Django site than the one currently
+# serving the request, while SAML auth itself already completed (SAMLAuthBackend.get_idp()
+# has no site check). Without this flag, pipeline steps such as should_force_account_creation()
+# cannot read provider flags (e.g. send_to_registration_first), causing new users to land on
+# the login page instead of registration.
+# .. toggle_use_cases: temporary
+# .. toggle_creation_date: 2026-02-19
+# .. toggle_target_removal_date: 2026-06-01
+# .. toggle_warning: The underlying site configuration mismatch should still be fixed in Django
+# admin (SAMLConfiguration and SAMLProviderConfig must reference the correct site). This flag
+# is a temporary workaround until that is resolved.
+SAML_PROVIDER_SITE_FALLBACK_FLAG = WaffleFlag(
+ f'{THIRD_PARTY_AUTH_NAMESPACE}.saml_provider_site_fallback', __name__
+)
+
+
+def is_saml_provider_site_fallback_enabled():
+ """
+ Returns True if get_from_pipeline() should fall back to a site-independent
+ SAMLProviderConfig lookup when the site-filtered registry finds no match.
+ """
+ return SAML_PROVIDER_SITE_FALLBACK_FLAG.is_enabled()
diff --git a/common/static/data/geoip/GeoLite2-Country.mmdb b/common/static/data/geoip/GeoLite2-Country.mmdb
index 5a7ccaaf8748..f51bdd7c336e 100644
Binary files a/common/static/data/geoip/GeoLite2-Country.mmdb and b/common/static/data/geoip/GeoLite2-Country.mmdb differ
diff --git a/common/static/js/vendor/pdfjs/viewer.js b/common/static/js/vendor/pdfjs/viewer.js
index bfa90d05a782..e4e4009d3ce0 100644
--- a/common/static/js/vendor/pdfjs/viewer.js
+++ b/common/static/js/vendor/pdfjs/viewer.js
@@ -27,7 +27,7 @@
'use strict';
-var DEFAULT_URL = 'compressed.tracemonkey-pldi-09.pdf';
+var DEFAULT_URL = '';
var DEFAULT_SCALE_DELTA = 1.1;
var MIN_SCALE = 0.25;
var MAX_SCALE = 10.0;
diff --git a/lms/djangoapps/branding/tests/test_views.py b/lms/djangoapps/branding/tests/test_views.py
index 36ebcd73509e..10c27192d11a 100644
--- a/lms/djangoapps/branding/tests/test_views.py
+++ b/lms/djangoapps/branding/tests/test_views.py
@@ -269,6 +269,20 @@ def test_index_does_not_redirect_without_site_override(self):
response = self.client.get(reverse("root"))
assert response.status_code == 200
+ @override_settings(ENABLE_MKTG_SITE=True)
+ @override_settings(MKTG_URLS={'ROOT': 'https://foo.bar/'})
+ @override_settings(LMS_ROOT_URL='https://foo.bar/')
+ def test_index_wont_redirect_to_marketing_root_if_it_matches_lms_root(self):
+ response = self.client.get(reverse("root"))
+ assert response.status_code == 200
+
+ @override_settings(ENABLE_MKTG_SITE=True)
+ @override_settings(MKTG_URLS={'ROOT': 'https://home.foo.bar/'})
+ @override_settings(LMS_ROOT_URL='https://foo.bar/')
+ def test_index_will_redirect_to_new_root_if_mktg_site_is_enabled(self):
+ response = self.client.get(reverse("root"))
+ assert response.status_code == 302
+
def test_index_redirects_to_marketing_site_with_site_override(self):
""" Test index view redirects if MKTG_URLS['ROOT'] is set in SiteConfiguration """
self.use_site(self.site_other)
diff --git a/lms/djangoapps/branding/views.py b/lms/djangoapps/branding/views.py
index 711adb85afec..33c5813f16ff 100644
--- a/lms/djangoapps/branding/views.py
+++ b/lms/djangoapps/branding/views.py
@@ -42,7 +42,7 @@ def index(request):
# page to make it easier to browse for courses (and register)
if configuration_helpers.get_value(
'ALWAYS_REDIRECT_HOMEPAGE_TO_DASHBOARD_FOR_AUTHENTICATED_USER',
- settings.FEATURES.get('ALWAYS_REDIRECT_HOMEPAGE_TO_DASHBOARD_FOR_AUTHENTICATED_USER', True)):
+ getattr(settings, 'ALWAYS_REDIRECT_HOMEPAGE_TO_DASHBOARD_FOR_AUTHENTICATED_USER', True)):
return redirect('dashboard')
if use_catalog_mfe():
@@ -50,7 +50,7 @@ def index(request):
enable_mktg_site = configuration_helpers.get_value(
'ENABLE_MKTG_SITE',
- settings.FEATURES.get('ENABLE_MKTG_SITE', False)
+ getattr(settings, 'ENABLE_MKTG_SITE', False)
)
if enable_mktg_site:
@@ -58,7 +58,9 @@ def index(request):
'MKTG_URLS',
settings.MKTG_URLS
)
- return redirect(marketing_urls.get('ROOT'))
+ root_url = marketing_urls.get("ROOT")
+ if root_url != getattr(settings, "LMS_ROOT_URL", None):
+ return redirect(root_url)
domain = request.headers.get('Host')
diff --git a/lms/djangoapps/bulk_email/signals.py b/lms/djangoapps/bulk_email/signals.py
index 7402dca75482..da18a459aeaa 100644
--- a/lms/djangoapps/bulk_email/signals.py
+++ b/lms/djangoapps/bulk_email/signals.py
@@ -7,6 +7,7 @@
from eventtracking import tracker
from common.djangoapps.student.models import CourseEnrollment
+from openedx.core.djangoapps.content.course_overviews.models import CourseOverview
from openedx.core.djangoapps.user_api.accounts.signals import USER_RETIRE_MAILINGS
from edx_ace.signals import ACE_MESSAGE_SENT
@@ -27,7 +28,14 @@ def force_optout_all(sender, **kwargs): # lint-amnesty, pylint: disable=unused-
raise TypeError('Expected a User type, but received None.')
for enrollment in CourseEnrollment.objects.filter(user=user):
- Optout.objects.get_or_create(user=user, course_id=enrollment.course.id)
+ try:
+ Optout.objects.get_or_create(user=user, course_id=enrollment.course.id)
+ except CourseOverview.DoesNotExist:
+ log.warning(
+ f"CourseOverview not found for enrollment {enrollment.id} (user: {user.id}), "
+ f"skipping optout creation. This may mean the course was deleted."
+ )
+ continue
@receiver(ACE_MESSAGE_SENT)
diff --git a/lms/djangoapps/bulk_email/tests/test_signals.py b/lms/djangoapps/bulk_email/tests/test_signals.py
index 1a3715284b12..01ad9312da4c 100644
--- a/lms/djangoapps/bulk_email/tests/test_signals.py
+++ b/lms/djangoapps/bulk_email/tests/test_signals.py
@@ -10,9 +10,11 @@
from django.core.management import call_command
from django.urls import reverse
+from common.djangoapps.student.models import CourseEnrollment
from common.djangoapps.student.tests.factories import AdminFactory, CourseEnrollmentFactory, UserFactory
from lms.djangoapps.bulk_email.models import BulkEmailFlag, Optout
from lms.djangoapps.bulk_email.signals import force_optout_all
+from opaque_keys.edx.keys import CourseKey
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase # lint-amnesty, pylint: disable=wrong-import-order
from xmodule.modulestore.tests.factories import CourseFactory # lint-amnesty, pylint: disable=wrong-import-order
@@ -85,3 +87,41 @@ def test_optout_course(self):
assert len(mail.outbox) == 1
assert len(mail.outbox[0].to) == 1
assert mail.outbox[0].to[0] == self.instructor.email
+
+ @patch('lms.djangoapps.bulk_email.signals.log.warning')
+ def test_optout_handles_missing_course_overview(self, mock_log_warning):
+ """
+ Test that force_optout_all gracefully handles CourseEnrollments
+ with missing CourseOverview records
+ """
+ # Create a course key for a course that doesn't exist in CourseOverview
+ nonexistent_course_key = CourseKey.from_string('course-v1:TestX+Missing+2023')
+
+ # Create an enrollment with a course_id that doesn't have a CourseOverview
+ CourseEnrollment.objects.create(
+ user=self.student,
+ course_id=nonexistent_course_key,
+ mode='honor'
+ )
+
+ # Verify the orphaned enrollment exists
+ assert CourseEnrollment.objects.filter(
+ user=self.student,
+ course_id=nonexistent_course_key
+ ).exists()
+
+ force_optout_all(sender=self.__class__, user=self.student)
+
+ # Verify that a warning was logged for the missing CourseOverview
+ mock_log_warning.assert_called()
+ call_args = mock_log_warning.call_args[0][0]
+ assert "CourseOverview not found for enrollment" in call_args
+ assert f"user: {self.student.id}" in call_args
+ assert "skipping optout creation" in call_args
+
+ # Verify that optouts were created for valid courses only
+ valid_course_optouts = Optout.objects.filter(user=self.student, course_id=self.course.id)
+ missing_course_optouts = Optout.objects.filter(user=self.student, course_id=nonexistent_course_key)
+
+ assert valid_course_optouts.count() == 1
+ assert missing_course_optouts.count() == 0
diff --git a/lms/djangoapps/course_home_api/outline/serializers.py b/lms/djangoapps/course_home_api/outline/serializers.py
index cfa518138a95..f66012362327 100644
--- a/lms/djangoapps/course_home_api/outline/serializers.py
+++ b/lms/djangoapps/course_home_api/outline/serializers.py
@@ -62,6 +62,7 @@ def get_blocks(self, block): # pylint: disable=missing-function-docstring
'type': block_type,
'has_scheduled_content': block.get('has_scheduled_content'),
'hide_from_toc': block.get('hide_from_toc'),
+ 'is_preview': block.get('is_preview', False),
},
}
if 'special_exam_info' in self.context.get('extra_fields', []) and block.get('special_exam_info'):
diff --git a/lms/djangoapps/course_home_api/outline/tests/test_view.py b/lms/djangoapps/course_home_api/outline/tests/test_view.py
index 74e22e5fcc4b..6de5db83f94c 100644
--- a/lms/djangoapps/course_home_api/outline/tests/test_view.py
+++ b/lms/djangoapps/course_home_api/outline/tests/test_view.py
@@ -13,6 +13,7 @@
from django.test import override_settings
from django.urls import reverse
from edx_toggles.toggles.testutils import override_waffle_flag
+from opaque_keys.edx.keys import UsageKey
from cms.djangoapps.contentstore.outlines import update_outline_from_modulestore
from common.djangoapps.course_modes.models import CourseMode
@@ -43,6 +44,7 @@
BlockFactory,
CourseFactory
)
+from xmodule.partitions.partitions import ENROLLMENT_TRACK_PARTITION_ID
@ddt.ddt
@@ -484,6 +486,89 @@ def test_course_progress_analytics_disabled(self, mock_task):
self.client.get(self.url)
mock_task.assert_not_called()
+ # Tests for verified content preview functionality
+ # These tests cover the feature that allows audit learners to preview
+ # the structure of verified-only content without access to the content itself
+
+ @patch('lms.djangoapps.course_home_api.outline.views.learner_can_preview_verified_content')
+ def test_verified_content_preview_disabled_integration(self, mock_preview_function):
+ """Test that when verified preview is disabled, no preview markers are added."""
+ # Given a course with some Verified only sequences
+ with self.store.bulk_operations(self.course.id):
+ chapter = BlockFactory.create(category='chapter', parent_location=self.course.location)
+ sequential = BlockFactory.create(
+ category='sequential',
+ parent_location=chapter.location,
+ display_name='Verified Sequential',
+ group_access={ENROLLMENT_TRACK_PARTITION_ID: [2]} # restrict to verified only
+ )
+ update_outline_from_modulestore(self.course.id)
+
+ # ... where the preview feature is disabled
+ mock_preview_function.return_value = False
+
+ # When I access them as an audit user
+ CourseEnrollment.enroll(self.user, self.course.id, CourseMode.AUDIT)
+ response = self.client.get(self.url)
+
+ # Then I get a valid response back
+ assert response.status_code == 200
+
+ # ... with course_blocks populated
+ course_blocks = response.data['course_blocks']["blocks"]
+
+ # ... but with verified content omitted
+ assert str(sequential.location) not in course_blocks
+
+ # ... and no block has preview set to true
+ for block in course_blocks:
+ assert course_blocks[block].get('is_preview') is not True
+
+ @patch('lms.djangoapps.course_home_api.outline.views.learner_can_preview_verified_content')
+ @patch('lms.djangoapps.course_home_api.outline.views.get_user_course_outline')
+ def test_verified_content_preview_enabled_marks_previewable_content(self, mock_outline, mock_preview_enabled):
+ """Test that when verified preview is enabled, previewable sequences and chapters are marked."""
+ # Given a course with some Verified only sequences and some regular sequences
+ with self.store.bulk_operations(self.course.id):
+ chapter = BlockFactory.create(category='chapter', parent_location=self.course.location)
+ verified_sequential = BlockFactory.create(
+ category='sequential',
+ parent_location=chapter.location,
+ display_name='Verified Sequential',
+ )
+ regular_sequential = BlockFactory.create(
+ category='sequential',
+ parent_location=chapter.location,
+ display_name='Regular Sequential'
+ )
+ update_outline_from_modulestore(self.course.id)
+
+ # ... with an outline that correctly identifies previewable sequences
+ mock_course_outline = Mock()
+ mock_course_outline.sections = {Mock(usage_key=chapter.location)}
+ mock_course_outline.sequences = {verified_sequential.location, regular_sequential.location}
+ mock_course_outline.previewable_sequences = {verified_sequential.location}
+ mock_outline.return_value = mock_course_outline
+
+ # When I access them as an audit user with preview enabled
+ CourseEnrollment.enroll(self.user, self.course.id, CourseMode.AUDIT)
+ mock_preview_enabled.return_value = True
+
+ # Then I get a valid response back
+ response = self.client.get(self.url)
+ assert response.status_code == 200
+
+ # ... with course_blocks populated
+ course_blocks = response.data['course_blocks']["blocks"]
+
+ for block in course_blocks:
+ # ... and the verified only content is marked as preview only
+ if UsageKey.from_string(block) in mock_course_outline.previewable_sequences:
+ assert course_blocks[block].get('is_preview') is True
+ # ... and the regular content is not marked as preview
+ else:
+ assert course_blocks[block].get('is_preview') is False
+
@ddt.ddt
class SidebarBlocksTestViews(BaseCourseHomeTests):
diff --git a/lms/djangoapps/course_home_api/outline/views.py b/lms/djangoapps/course_home_api/outline/views.py
index 7c5307cba764..78d5767ffeed 100644
--- a/lms/djangoapps/course_home_api/outline/views.py
+++ b/lms/djangoapps/course_home_api/outline/views.py
@@ -36,7 +36,10 @@
)
from lms.djangoapps.course_home_api.utils import get_course_or_403
from lms.djangoapps.course_home_api.tasks import collect_progress_for_user_in_course
-from lms.djangoapps.course_home_api.toggles import send_course_progress_analytics_for_student_is_enabled
+from lms.djangoapps.course_home_api.toggles import (
+ learner_can_preview_verified_content,
+ send_course_progress_analytics_for_student_is_enabled,
+)
from lms.djangoapps.courseware.access import has_access
from lms.djangoapps.courseware.context_processor import user_timezone_locale_prefs
from lms.djangoapps.courseware.courses import get_course_date_blocks, get_course_info_section
@@ -209,6 +212,7 @@ def get(self, request, *args, **kwargs): # pylint: disable=too-many-statements
allow_anonymous = COURSE_ENABLE_UNENROLLED_ACCESS_FLAG.is_enabled(course_key)
allow_public = allow_anonymous and course.course_visibility == COURSE_VISIBILITY_PUBLIC
allow_public_outline = allow_anonymous and course.course_visibility == COURSE_VISIBILITY_PUBLIC_OUTLINE
+ allow_preview_of_verified_content = learner_can_preview_verified_content(course_key, request.user)
# User locale settings
user_timezone_locale = user_timezone_locale_prefs(request)
@@ -309,7 +313,8 @@ def get(self, request, *args, **kwargs): # pylint: disable=too-many-statements
# so this is a tiny first step in that migration.
if course_blocks:
user_course_outline = get_user_course_outline(
- course_key, request.user, datetime.now(tz=timezone.utc)
+ course_key, request.user, datetime.now(tz=timezone.utc),
+ preview_verified_content=allow_preview_of_verified_content
)
available_seq_ids = {str(usage_key) for usage_key in user_course_outline.sequences}
@@ -339,6 +344,19 @@ def get(self, request, *args, **kwargs): # pylint: disable=too-many-statements
)
] if 'children' in chapter_data else []
+ # For audit preview of verified content, we don't remove verified content.
+ # Instead, we mark it as preview so the frontend can handle it appropriately.
+ if allow_preview_of_verified_content:
+ previewable_sequences = {str(usage_key) for usage_key in user_course_outline.previewable_sequences}
+
+ # Iterate through course_blocks to mark previewable sequences and chapters
+ for chapter_data in course_blocks['children']:
+ if chapter_data['id'] in previewable_sequences:
+ chapter_data['is_preview'] = True
+ for seq_data in chapter_data.get('children', []):
+ if seq_data['id'] in previewable_sequences:
+ seq_data['is_preview'] = True
+
user_has_passing_grade = False
if not request.user.is_anonymous:
user_grade = CourseGradeFactory().read(request.user, course)
diff --git a/lms/djangoapps/course_home_api/progress/api.py b/lms/djangoapps/course_home_api/progress/api.py
index b2a8634c59f4..f89ecd3d2596 100644
--- a/lms/djangoapps/course_home_api/progress/api.py
+++ b/lms/djangoapps/course_home_api/progress/api.py
@@ -2,14 +2,226 @@
Python APIs exposed for the progress tracking functionality of the course home API.
"""
+from __future__ import annotations
+
from django.contrib.auth import get_user_model
from opaque_keys.edx.keys import CourseKey
+from openedx.core.lib.grade_utils import round_away_from_zero
+from xmodule.graders import ShowCorrectness
+from datetime import datetime, timezone
from lms.djangoapps.courseware.courses import get_course_blocks_completion_summary
+from dataclasses import dataclass, field
User = get_user_model()
+@dataclass
+class _AssignmentBucket:
+ """Holds scores and visibility info for one assignment type.
+
+ Attributes:
+ assignment_type: Full assignment type name from the grading policy (for example, "Homework").
+ num_total: The total number of assignments expected to contribute to the grade before any
+ drop-lowest rules are applied.
+ last_grade_publish_date: The most recent date when grades for all assignments of assignment_type
+ are released and included in the final grade.
+ scores: Per-subsection fractional scores (each value is ``earned / possible`` and falls in
+ the range 0–1). While awaiting published content we pad the list with zero placeholders
+ so that its length always matches ``num_total`` until real scores replace them.
+ visibilities: Mirrors ``scores`` index-for-index and records whether each subsection's
+ correctness feedback is visible to the learner (``True``), hidden (``False``), or not
+ yet populated (``None`` when the entry is a placeholder).
+ included: Tracks whether each subsection currently counts toward the learner's grade as
+ determined by ``SubsectionGrade.show_grades``. Values follow the same convention as
+ ``visibilities`` (``True`` / ``False`` / ``None`` placeholders).
+ assignments_created: Count of real subsections inserted into the bucket so far. Once this
+ reaches ``num_total``, all placeholder entries have been replaced with actual data.
+ """
+ assignment_type: str
+ num_total: int
+ last_grade_publish_date: datetime
+ scores: list[float] = field(default_factory=list)
+ visibilities: list[bool | None] = field(default_factory=list)
+ included: list[bool | None] = field(default_factory=list)
+ assignments_created: int = 0
+
+ @classmethod
+ def with_placeholders(cls, assignment_type: str, num_total: int, now: datetime):
+ """Create a bucket prefilled with placeholder (empty) entries."""
+ return cls(
+ assignment_type=assignment_type,
+ num_total=num_total,
+ last_grade_publish_date=now,
+ scores=[0] * num_total,
+ visibilities=[None] * num_total,
+ included=[None] * num_total,
+ )
+
+ def add_subsection(self, score: float, is_visible: bool, is_included: bool):
+ """Add a subsection’s score and visibility, replacing a placeholder if space remains."""
+ if self.assignments_created < self.num_total:
+ if self.scores:
+ self.scores.pop(0)
+ if self.visibilities:
+ self.visibilities.pop(0)
+ if self.included:
+ self.included.pop(0)
+ self.scores.append(score)
+ self.visibilities.append(is_visible)
+ self.included.append(is_included)
+ self.assignments_created += 1
+
+ def drop_lowest(self, num_droppable: int):
+ """Remove the lowest scoring subsections, up to the provided num_droppable."""
+ while num_droppable > 0 and self.scores:
+ idx = self.scores.index(min(self.scores))
+ self.scores.pop(idx)
+ self.visibilities.pop(idx)
+ self.included.pop(idx)
+ num_droppable -= 1
+
+ def hidden_state(self) -> str:
+ """Return whether kept scores are all, some, or none hidden."""
+ if not self.visibilities:
+ return 'none'
+ all_hidden = all(v is False for v in self.visibilities)
+ some_hidden = any(v is False for v in self.visibilities)
+ if all_hidden:
+ return 'all'
+ if some_hidden:
+ return 'some'
+ return 'none'
+
+ def averages(self) -> tuple[float, float]:
+ """Compute visible and included averages over kept scores.
+
+        The visible average uses only grades whose visibility flag is True in the numerator; the
+        denominator is the total number of kept scores (mirrors legacy behavior). The included
+        average uses only scores marked as included (show_grades True), with the same denominator.
+
+ Returns:
+ (earned_visible, earned_all) tuple of floats (0-1 each).
+ """
+ if not self.scores:
+ return 0.0, 0.0
+ visible_scores = [s for i, s in enumerate(self.scores) if self.visibilities[i]]
+ included_scores = [s for i, s in enumerate(self.scores) if self.included[i]]
+ earned_visible = (sum(visible_scores) / len(self.scores)) if self.scores else 0.0
+ earned_all = (sum(included_scores) / len(self.scores)) if self.scores else 0.0
+ return earned_visible, earned_all
+
+
+class _AssignmentTypeGradeAggregator:
+ """Collects and aggregates subsection grades by assignment type."""
+
+ def __init__(self, course_grade, grading_policy: dict, has_staff_access: bool):
+ """Initialize with course grades, grading policy, and staff access flag."""
+ self.course_grade = course_grade
+ self.grading_policy = grading_policy
+ self.has_staff_access = has_staff_access
+ self.now = datetime.now(timezone.utc)
+ self.policy_map = self._build_policy_map()
+ self.buckets: dict[str, _AssignmentBucket] = {}
+
+ def _build_policy_map(self) -> dict:
+ """Convert grading policy into a lookup of assignment type → policy info."""
+ policy_map = {}
+ for policy in self.grading_policy.get('GRADER', []):
+ policy_map[policy.get('type')] = {
+ 'weight': policy.get('weight', 0.0),
+ 'short_label': policy.get('short_label', ''),
+ 'num_droppable': policy.get('drop_count', 0),
+ 'num_total': policy.get('min_count', 0),
+ }
+ return policy_map
+
+ def _bucket_for(self, assignment_type: str) -> _AssignmentBucket:
+ """Get or create a score bucket for the given assignment type."""
+ bucket = self.buckets.get(assignment_type)
+ if bucket is None:
+ num_total = self.policy_map.get(assignment_type, {}).get('num_total', 0) or 0
+ bucket = _AssignmentBucket.with_placeholders(assignment_type, num_total, self.now)
+ self.buckets[assignment_type] = bucket
+ return bucket
+
+ def collect(self):
+ """Gather subsection grades into their respective assignment buckets."""
+ for chapter in self.course_grade.chapter_grades.values():
+ for subsection_grade in chapter.get('sections', []):
+ if not getattr(subsection_grade, 'graded', False):
+ continue
+ assignment_type = getattr(subsection_grade, 'format', '') or ''
+ if not assignment_type:
+ continue
+ graded_total = getattr(subsection_grade, 'graded_total', None)
+ earned = getattr(graded_total, 'earned', 0.0) if graded_total else 0.0
+ possible = getattr(graded_total, 'possible', 0.0) if graded_total else 0.0
+ earned = 0.0 if earned is None else earned
+ possible = 0.0 if possible is None else possible
+ score = (earned / possible) if possible else 0.0
+ is_visible = ShowCorrectness.correctness_available(
+ subsection_grade.show_correctness, subsection_grade.due, self.has_staff_access
+ )
+ is_included = subsection_grade.show_grades(self.has_staff_access)
+ bucket = self._bucket_for(assignment_type)
+ bucket.add_subsection(score, is_visible, is_included)
+ visibilities_with_due_dates = [ShowCorrectness.PAST_DUE, ShowCorrectness.NEVER_BUT_INCLUDE_GRADE]
+ if subsection_grade.show_correctness in visibilities_with_due_dates:
+ if subsection_grade.due and subsection_grade.due > bucket.last_grade_publish_date:
+ bucket.last_grade_publish_date = subsection_grade.due
+
+ def build_results(self) -> dict:
+ """Apply drops, compute averages, and return aggregated results and total grade."""
+ final_grades = 0.0
+ rows = []
+ for assignment_type, bucket in self.buckets.items():
+ policy = self.policy_map.get(assignment_type, {})
+ bucket.drop_lowest(policy.get('num_droppable', 0))
+ earned_visible, earned_all = bucket.averages()
+ weight = policy.get('weight', 0.0)
+ short_label = policy.get('short_label', '')
+ row = {
+ 'type': assignment_type,
+ 'weight': weight,
+ 'average_grade': round_away_from_zero(earned_visible, 4),
+ 'weighted_grade': round_away_from_zero(earned_visible * weight, 4),
+ 'short_label': short_label,
+ 'num_droppable': policy.get('num_droppable', 0),
+ 'last_grade_publish_date': bucket.last_grade_publish_date,
+ 'has_hidden_contribution': bucket.hidden_state(),
+ }
+ final_grades += earned_all * weight
+ rows.append(row)
+ rows.sort(key=lambda r: r['weight'])
+ return {'results': rows, 'final_grades': round_away_from_zero(final_grades, 4)}
+
+ def run(self) -> dict:
+ """Execute full pipeline (collect + aggregate) returning final payload."""
+ self.collect()
+ return self.build_results()
+
+
+def aggregate_assignment_type_grade_summary(
+ course_grade,
+ grading_policy: dict,
+ has_staff_access: bool = False,
+) -> dict:
+ """
+ Aggregate subsection grades by assignment type and return summary data.
+ Args:
+ course_grade: CourseGrade object containing chapter and subsection grades.
+ grading_policy: Dictionary representing the course's grading policy.
+ has_staff_access: Boolean indicating if the user has staff access to view all grades.
+ Returns:
+ Dictionary with keys:
+ results: list of per-assignment-type summary dicts
+ final_grades: overall weighted contribution (float, 4 decimal rounding)
+ """
+ aggregator = _AssignmentTypeGradeAggregator(course_grade, grading_policy, has_staff_access)
+ return aggregator.run()
+
+
def calculate_progress_for_learner_in_course(course_key: CourseKey, user: User) -> dict:
"""
Calculate a given learner's progress in the specified course run.
diff --git a/lms/djangoapps/course_home_api/progress/serializers.py b/lms/djangoapps/course_home_api/progress/serializers.py
index 6bdc204434af..c48660a41c6a 100644
--- a/lms/djangoapps/course_home_api/progress/serializers.py
+++ b/lms/djangoapps/course_home_api/progress/serializers.py
@@ -26,6 +26,7 @@ class SubsectionScoresSerializer(ReadOnlySerializer):
assignment_type = serializers.CharField(source='format')
block_key = serializers.SerializerMethodField()
display_name = serializers.CharField()
+ due = serializers.DateTimeField(allow_null=True)
has_graded_assignment = serializers.BooleanField(source='graded')
override = serializers.SerializerMethodField()
learner_has_access = serializers.SerializerMethodField()
@@ -127,6 +128,20 @@ class VerificationDataSerializer(ReadOnlySerializer):
status_date = serializers.DateTimeField()
+class AssignmentTypeScoresSerializer(ReadOnlySerializer):
+ """
+ Serializer for aggregated scores per assignment type.
+ """
+ type = serializers.CharField()
+ weight = serializers.FloatField()
+ average_grade = serializers.FloatField()
+ weighted_grade = serializers.FloatField()
+ last_grade_publish_date = serializers.DateTimeField()
+ has_hidden_contribution = serializers.CharField()
+ short_label = serializers.CharField()
+ num_droppable = serializers.IntegerField()
+
+
class ProgressTabSerializer(VerifiedModeSerializer):
"""
Serializer for progress tab
@@ -146,3 +161,5 @@ class ProgressTabSerializer(VerifiedModeSerializer):
user_has_passing_grade = serializers.BooleanField()
verification_data = VerificationDataSerializer()
disable_progress_graph = serializers.BooleanField()
+ assignment_type_grade_summary = AssignmentTypeScoresSerializer(many=True)
+ final_grades = serializers.FloatField()
diff --git a/lms/djangoapps/course_home_api/progress/tests/test_api.py b/lms/djangoapps/course_home_api/progress/tests/test_api.py
index 30d8d9059eaa..51e7dd68286e 100644
--- a/lms/djangoapps/course_home_api/progress/tests/test_api.py
+++ b/lms/djangoapps/course_home_api/progress/tests/test_api.py
@@ -6,7 +6,80 @@
from django.test import TestCase
-from lms.djangoapps.course_home_api.progress.api import calculate_progress_for_learner_in_course
+from lms.djangoapps.course_home_api.progress.api import (
+ calculate_progress_for_learner_in_course,
+ aggregate_assignment_type_grade_summary,
+)
+from xmodule.graders import ShowCorrectness
+from datetime import datetime, timedelta, timezone
+from types import SimpleNamespace
+
+
+def _make_subsection(fmt, earned, possible, show_corr, *, due_delta_days=None):
+ """Build a lightweight subsection object for testing aggregation scenarios."""
+ graded_total = SimpleNamespace(earned=earned, possible=possible)
+ due = None
+ if due_delta_days is not None:
+ due = datetime.now(timezone.utc) + timedelta(days=due_delta_days)
+ return SimpleNamespace(
+ graded=True,
+ format=fmt,
+ graded_total=graded_total,
+ show_correctness=show_corr,
+ due=due,
+ show_grades=lambda staff: True,
+ )
+
+
+_AGGREGATION_SCENARIOS = [
+ (
+ 'all_visible_always',
+ {'type': 'Homework', 'weight': 1.0, 'drop_count': 0, 'min_count': 2, 'short_label': 'HW'},
+ [
+ _make_subsection('Homework', 1, 1, ShowCorrectness.ALWAYS),
+ _make_subsection('Homework', 0.5, 1, ShowCorrectness.ALWAYS),
+ ],
+ {'avg': 0.75, 'weighted': 0.75, 'hidden': 'none', 'final': 0.75},
+ ),
+ (
+ 'some_hidden_never_but_include',
+ {'type': 'Exam', 'weight': 1.0, 'drop_count': 0, 'min_count': 2, 'short_label': 'EX'},
+ [
+ _make_subsection('Exam', 1, 1, ShowCorrectness.ALWAYS),
+ _make_subsection('Exam', 0.5, 1, ShowCorrectness.NEVER_BUT_INCLUDE_GRADE),
+ ],
+ {'avg': 0.5, 'weighted': 0.5, 'hidden': 'some', 'final': 0.75},
+ ),
+ (
+ 'all_hidden_never_but_include',
+ {'type': 'Quiz', 'weight': 1.0, 'drop_count': 0, 'min_count': 2, 'short_label': 'QZ'},
+ [
+ _make_subsection('Quiz', 0.4, 1, ShowCorrectness.NEVER_BUT_INCLUDE_GRADE),
+ _make_subsection('Quiz', 0.6, 1, ShowCorrectness.NEVER_BUT_INCLUDE_GRADE),
+ ],
+ {'avg': 0.0, 'weighted': 0.0, 'hidden': 'all', 'final': 0.5},
+ ),
+ (
+ 'past_due_mixed_visibility',
+ {'type': 'Lab', 'weight': 1.0, 'drop_count': 0, 'min_count': 2, 'short_label': 'LB'},
+ [
+ _make_subsection('Lab', 0.8, 1, ShowCorrectness.PAST_DUE, due_delta_days=-1),
+ _make_subsection('Lab', 0.2, 1, ShowCorrectness.PAST_DUE, due_delta_days=+3),
+ ],
+ {'avg': 0.4, 'weighted': 0.4, 'hidden': 'some', 'final': 0.5},
+ ),
+ (
+ 'drop_lowest_keeps_high_scores',
+ {'type': 'Project', 'weight': 1.0, 'drop_count': 2, 'min_count': 4, 'short_label': 'PR'},
+ [
+ _make_subsection('Project', 1, 1, ShowCorrectness.ALWAYS),
+ _make_subsection('Project', 1, 1, ShowCorrectness.ALWAYS),
+ _make_subsection('Project', 0, 1, ShowCorrectness.ALWAYS),
+ _make_subsection('Project', 0, 1, ShowCorrectness.ALWAYS),
+ ],
+ {'avg': 1.0, 'weighted': 1.0, 'hidden': 'none', 'final': 1.0},
+ ),
+]
class ProgressApiTests(TestCase):
@@ -73,3 +146,37 @@ def test_calculate_progress_for_learner_in_course_summary_empty(self, mock_get_s
results = calculate_progress_for_learner_in_course("some_course", "some_user")
assert not results
+
+ def test_aggregate_assignment_type_grade_summary_scenarios(self):
+ """
+ A test to verify functionality of aggregate_assignment_type_grade_summary.
+ 1. Test visibility modes (always, never but include grade, past due)
+ 2. Test drop-lowest behavior
+ 3. Test weighting behavior
+ 4. Test final grade calculation
+ 5. Test average grade calculation
+ 6. Test weighted grade calculation
+ 7. Test has_hidden_contribution calculation
+ """
+
+ for case_name, policy, subsections, expected in _AGGREGATION_SCENARIOS:
+ with self.subTest(case_name=case_name):
+ course_grade = SimpleNamespace(chapter_grades={'chapter': {'sections': subsections}})
+ grading_policy = {'GRADER': [policy]}
+
+ result = aggregate_assignment_type_grade_summary(
+ course_grade,
+ grading_policy,
+ has_staff_access=False,
+ )
+
+ assert 'results' in result and 'final_grades' in result
+ assert result['final_grades'] == expected['final']
+ assert len(result['results']) == 1
+
+ row = result['results'][0]
+ assert row['type'] == policy['type'], case_name
+ assert row['average_grade'] == expected['avg']
+ assert row['weighted_grade'] == expected['weighted']
+ assert row['has_hidden_contribution'] == expected['hidden']
+ assert row['num_droppable'] == policy['drop_count']
diff --git a/lms/djangoapps/course_home_api/progress/tests/test_views.py b/lms/djangoapps/course_home_api/progress/tests/test_views.py
index d13ebec29c21..8012e11675f1 100644
--- a/lms/djangoapps/course_home_api/progress/tests/test_views.py
+++ b/lms/djangoapps/course_home_api/progress/tests/test_views.py
@@ -282,8 +282,8 @@ def test_url_hidden_if_subsection_hide_after_due(self):
assert hide_after_due_subsection['url'] is None
@ddt.data(
- (True, 0.7), # midterm and final are visible to staff
- (False, 0.3), # just the midterm is visible to learners
+ (True, 0.72), # lab, midterm and final are visible to staff
+ (False, 0.32), # only lab and midterm are visible to learners
)
@ddt.unpack
def test_course_grade_considers_subsection_grade_visibility(self, is_staff, expected_percent):
@@ -301,14 +301,18 @@ def test_course_grade_considers_subsection_grade_visibility(self, is_staff, expe
never = self.add_subsection_with_problem(format='Homework', show_correctness='never')
always = self.add_subsection_with_problem(format='Midterm Exam', show_correctness='always')
past_due = self.add_subsection_with_problem(format='Final Exam', show_correctness='past_due', due=tomorrow)
+ never_but_show_grade = self.add_subsection_with_problem(
+ format='Lab', show_correctness='never_but_include_grade'
+ )
answer_problem(self.course, get_mock_request(self.user), never)
answer_problem(self.course, get_mock_request(self.user), always)
answer_problem(self.course, get_mock_request(self.user), past_due)
+ answer_problem(self.course, get_mock_request(self.user), never_but_show_grade)
# First, confirm the grade in the database - it should never change based on user state.
# This is midterm and final and a single problem added together.
- assert CourseGradeFactory().read(self.user, self.course).percent == 0.72
+ assert CourseGradeFactory().read(self.user, self.course).percent == 0.73
response = self.client.get(self.url)
assert response.status_code == 200
diff --git a/lms/djangoapps/course_home_api/progress/views.py b/lms/djangoapps/course_home_api/progress/views.py
index 3783c19061dc..54e71df48cc5 100644
--- a/lms/djangoapps/course_home_api/progress/views.py
+++ b/lms/djangoapps/course_home_api/progress/views.py
@@ -13,8 +13,11 @@
from rest_framework.response import Response
from xmodule.modulestore.django import modulestore
+from xmodule.graders import ShowCorrectness
from common.djangoapps.student.models import CourseEnrollment
from lms.djangoapps.course_home_api.progress.serializers import ProgressTabSerializer
+from lms.djangoapps.course_home_api.progress.api import aggregate_assignment_type_grade_summary
+
from lms.djangoapps.course_home_api.toggles import course_home_mfe_progress_tab_is_active
from lms.djangoapps.courseware.access import has_access, has_ccx_coach_role
from lms.djangoapps.course_blocks.api import get_course_blocks
@@ -99,6 +102,7 @@ class ProgressTabView(RetrieveAPIView):
assignment_type: (str) the format, if any, of the Subsection (Homework, Exam, etc)
block_key: (str) the key of the given subsection block
display_name: (str) a str of what the name of the Subsection is for displaying on the site
+ due: (str or None) the due date of the subsection in ISO 8601 format, or None if no due date is set
has_graded_assignment: (bool) whether or not the Subsection is a graded assignment
learner_has_access: (bool) whether the learner has access to the subsection (could be FBE gated)
num_points_earned: (int) the amount of points the user has earned for the given subsection
@@ -175,6 +179,18 @@ def _get_student_user(self, request, course_key, student_id, is_staff):
except User.DoesNotExist as exc:
raise Http404 from exc
+ def _visible_section_scores(self, course_grade):
+ """Return only those chapter/section scores that are visible to the learner."""
+ visible_chapters = []
+ for chapter in course_grade.chapter_grades.values():
+ filtered_sections = [
+ subsection
+ for subsection in chapter["sections"]
+ if getattr(subsection, "show_correctness", None) != ShowCorrectness.NEVER_BUT_INCLUDE_GRADE
+ ]
+ visible_chapters.append({**chapter, "sections": filtered_sections})
+ return visible_chapters
+
def get(self, request, *args, **kwargs):
course_key_string = kwargs.get('course_key_string')
course_key = CourseKey.from_string(course_key_string)
@@ -245,6 +261,16 @@ def get(self, request, *args, **kwargs):
access_expiration = get_access_expiration_data(request.user, course_overview)
+ # Aggregations delegated to helper functions for reuse and testability
+ assignment_type_grade_summary = aggregate_assignment_type_grade_summary(
+ course_grade,
+ grading_policy,
+ has_staff_access=is_staff,
+ )
+
+ # Filter out section scores to only have those that are visible to the user
+ section_scores = self._visible_section_scores(course_grade)
+
data = {
'access_expiration': access_expiration,
'certificate_data': get_cert_data(student, course, enrollment_mode, course_grade),
@@ -255,12 +281,14 @@ def get(self, request, *args, **kwargs):
'enrollment_mode': enrollment_mode,
'grading_policy': grading_policy,
'has_scheduled_content': has_scheduled_content,
- 'section_scores': list(course_grade.chapter_grades.values()),
+ 'section_scores': section_scores,
'studio_url': get_studio_url(course, 'settings/grading'),
'username': username,
'user_has_passing_grade': user_has_passing_grade,
'verification_data': verification_data,
'disable_progress_graph': disable_progress_graph,
+ 'assignment_type_grade_summary': assignment_type_grade_summary["results"],
+ 'final_grades': assignment_type_grade_summary["final_grades"],
}
context = self.get_serializer_context()
context['staff_access'] = is_staff
diff --git a/lms/djangoapps/course_home_api/tests/test_toggles.py b/lms/djangoapps/course_home_api/tests/test_toggles.py
new file mode 100644
index 000000000000..46ab545d0ade
--- /dev/null
+++ b/lms/djangoapps/course_home_api/tests/test_toggles.py
@@ -0,0 +1,155 @@
+"""
+Tests for Course Home API toggles.
+"""
+
+from unittest.mock import Mock, patch
+
+from django.test import TestCase
+from opaque_keys.edx.keys import CourseKey
+
+from common.djangoapps.course_modes.models import CourseMode
+from common.djangoapps.course_modes.tests.factories import CourseModeFactory
+
+from ..toggles import learner_can_preview_verified_content
+
+
+class TestLearnerCanPreviewVerifiedContent(TestCase):
+ """Test cases for learner_can_preview_verified_content function."""
+
+ def setUp(self):
+ """Set up test fixtures."""
+ self.course_key = CourseKey.from_string("course-v1:TestX+CS101+2024")
+ self.user = Mock()
+
+ # Set up patchers
+ self.feature_enabled_patcher = patch(
+ "lms.djangoapps.course_home_api.toggles.audit_learner_verified_preview_is_enabled"
+ )
+ self.verified_mode_for_course_patcher = patch(
+ "common.djangoapps.course_modes.models.CourseMode.verified_mode_for_course"
+ )
+ self.get_enrollment_patcher = patch(
+ "common.djangoapps.student.models.CourseEnrollment.get_enrollment"
+ )
+
+ # Course set up with verified, professional, and audit modes
+ self.verified_mode = CourseModeFactory(
+ course_id=self.course_key,
+ mode_slug=CourseMode.VERIFIED,
+ mode_display_name="Verified",
+ )
+ self.professional_mode = CourseModeFactory(
+ course_id=self.course_key,
+ mode_slug=CourseMode.PROFESSIONAL,
+ mode_display_name="Professional",
+ )
+ self.audit_mode = CourseModeFactory(
+ course_id=self.course_key,
+ mode_slug=CourseMode.AUDIT,
+ mode_display_name="Audit",
+ )
+ self.course_modes_dict = {
+ "audit": self.audit_mode,
+ "verified": self.verified_mode,
+ "professional": self.professional_mode,
+ }
+
+ # Start patchers
+ self.mock_feature_enabled = self.feature_enabled_patcher.start()
+ self.mock_verified_mode_for_course = (
+ self.verified_mode_for_course_patcher.start()
+ )
+ self.mock_get_enrollment = self.get_enrollment_patcher.start()
+
+ def _enroll_user(self, mode):
+ """Helper method to set up user enrollment mock."""
+ mock_enrollment = Mock()
+ mock_enrollment.mode = mode
+ self.mock_get_enrollment.return_value = mock_enrollment
+
+ def tearDown(self):
+ """Clean up patchers."""
+ self.feature_enabled_patcher.stop()
+ self.verified_mode_for_course_patcher.stop()
+ self.get_enrollment_patcher.stop()
+
+ def test_all_conditions_met_returns_true(self):
+ """Test that function returns True when all conditions are met."""
+ # Given the feature is enabled, course has verified mode, and user is enrolled as audit
+ self.mock_feature_enabled.return_value = True
+ self.mock_verified_mode_for_course.return_value = self.course_modes_dict[
+ "professional"
+ ]
+ self._enroll_user(CourseMode.AUDIT)
+
+ # When I check if the learner can preview verified content
+ result = learner_can_preview_verified_content(self.course_key, self.user)
+
+ # Then the result should be True
+ self.assertTrue(result)
+
+ def test_feature_disabled_returns_false(self):
+ """Test that function returns False when feature is disabled."""
+ # Given the feature is disabled
+ self.mock_feature_enabled.return_value = False
+
+ # ... even if all other conditions are met
+ self.mock_verified_mode_for_course.return_value = self.course_modes_dict[
+ "professional"
+ ]
+ self._enroll_user(CourseMode.AUDIT)
+
+ # When I check if the learner can preview verified content
+ result = learner_can_preview_verified_content(self.course_key, self.user)
+
+ # Then the result should be False
+ self.assertFalse(result)
+
+ def test_no_verified_mode_returns_false(self):
+ """Test that function returns False when course has no verified mode."""
+ # Given the course does not have a verified mode
+ self.mock_verified_mode_for_course.return_value = None
+
+ # ... even if all other conditions are met
+ self.mock_feature_enabled.return_value = True
+ self._enroll_user(CourseMode.AUDIT)
+
+ # When I check if the learner can preview verified content
+ result = learner_can_preview_verified_content(self.course_key, self.user)
+
+ # Then the result should be False
+ self.assertFalse(result)
+
+ def test_no_enrollment_returns_false(self):
+ """Test that function returns False when user is not enrolled."""
+ # Given the user is unenrolled
+ self.mock_get_enrollment.return_value = None
+
+ # ... even if all other conditions are met
+ self.mock_feature_enabled.return_value = True
+ self.mock_verified_mode_for_course.return_value = self.course_modes_dict[
+ "professional"
+ ]
+
+ # When I check if the learner can preview verified content
+ result = learner_can_preview_verified_content(self.course_key, self.user)
+
+ # Then the result should be False
+ self.assertFalse(result)
+
+ def test_verified_enrollment_returns_false(self):
+ """Test that function returns False when user is enrolled in verified mode."""
+ # Given the user is not enrolled as audit
+ self._enroll_user(CourseMode.VERIFIED)
+
+ # ... even if all other conditions are met
+ self.mock_feature_enabled.return_value = True
+ self.mock_verified_mode_for_course.return_value = self.course_modes_dict[
+ "professional"
+ ]
+
+ # When I check if the learner can preview verified content
+ result = learner_can_preview_verified_content(self.course_key, self.user)
+
+ # Then the result should be False
+ self.assertFalse(result)
diff --git a/lms/djangoapps/course_home_api/toggles.py b/lms/djangoapps/course_home_api/toggles.py
index 052862796c75..1f2d32b87e96 100644
--- a/lms/djangoapps/course_home_api/toggles.py
+++ b/lms/djangoapps/course_home_api/toggles.py
@@ -3,6 +3,9 @@
"""
from openedx.core.djangoapps.waffle_utils import CourseWaffleFlag
+from openedx.core.lib.cache_utils import request_cached
+from common.djangoapps.course_modes.models import CourseMode
+from common.djangoapps.student.models import CourseEnrollment
WAFFLE_FLAG_NAMESPACE = 'course_home'
@@ -51,6 +54,21 @@
)
+# Waffle flag to enable audit learner preview of course structure visible to verified learners.
+#
+# .. toggle_name: course_home.audit_learner_verified_preview
+# .. toggle_implementation: CourseWaffleFlag
+# .. toggle_default: False
+# .. toggle_description: When enabled, audit learners can see the presence of the sections / units
+# otherwise restricted to verified learners. The content itself remains inaccessible.
+# .. toggle_use_cases: open_edx
+# .. toggle_creation_date: 2025-11-07
+# .. toggle_target_removal_date: None
+COURSE_HOME_AUDIT_LEARNER_VERIFIED_PREVIEW = CourseWaffleFlag(
+ f'{WAFFLE_FLAG_NAMESPACE}.audit_learner_verified_preview', __name__
+)
+
+
def course_home_mfe_progress_tab_is_active(course_key):
# Avoiding a circular dependency
from .models import DisableProgressPageStackedConfig
@@ -73,3 +91,40 @@ def send_course_progress_analytics_for_student_is_enabled(course_key):
Returns True if the course completion analytics feature is enabled for a given course.
"""
return COURSE_HOME_SEND_COURSE_PROGRESS_ANALYTICS_FOR_STUDENT.is_enabled(course_key)
+
+
+def audit_learner_verified_preview_is_enabled(course_key):
+ """
+ Returns True if the audit learner verified preview feature is enabled for a given course.
+ """
+ return COURSE_HOME_AUDIT_LEARNER_VERIFIED_PREVIEW.is_enabled(course_key)
+
+
+@request_cached()
+def learner_can_preview_verified_content(course_key, user):
+ """
+ Determine if an audit learner can preview verified content in a course.
+
+ Args:
+ course_key: The course identifier.
+ user: The user object
+ Returns:
+ True if the learner can preview verified content, False otherwise.
+ """
+ # To preview verified content, the feature must be enabled for the course...
+ feature_enabled = audit_learner_verified_preview_is_enabled(course_key)
+ if not feature_enabled:
+ return False
+
+ # ... the course must have a verified mode
+ course_has_verified_mode = CourseMode.verified_mode_for_course(course_key)
+ if not course_has_verified_mode:
+ return False
+
+ # ... and the user must be enrolled as audit
+ enrollment = CourseEnrollment.get_enrollment(user, course_key)
+ user_enrolled_as_audit = enrollment is not None and enrollment.mode == CourseMode.AUDIT
+ if not user_enrolled_as_audit:
+ return False
+
+ return True
diff --git a/lms/djangoapps/courseware/tests/test_views.py b/lms/djangoapps/courseware/tests/test_views.py
index 4e3d1be9bddc..4a3f16523ce5 100644
--- a/lms/djangoapps/courseware/tests/test_views.py
+++ b/lms/djangoapps/courseware/tests/test_views.py
@@ -79,6 +79,7 @@
COURSEWARE_MICROFRONTEND_ENABLE_NAVIGATION_SIDEBAR,
COURSEWARE_MICROFRONTEND_SEARCH_ENABLED,
COURSEWARE_OPTIMIZED_RENDER_XBLOCK,
+ ENABLE_UNIFIED_SITE_AND_TRANSLATION_LANGUAGE,
)
from completion.waffle import ENABLE_COMPLETION_TRACKING_SWITCH
from lms.djangoapps.courseware.user_state_client import DjangoXBlockUserStateClient
@@ -1781,6 +1782,14 @@ def assert_progress_page_show_grades(self, response, show_correctness, due_date,
(ShowCorrectness.PAST_DUE, TODAY, True),
(ShowCorrectness.PAST_DUE, TOMORROW, False),
(ShowCorrectness.PAST_DUE, TOMORROW, True),
+ (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, None, False),
+ (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, None, True),
+ (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, YESTERDAY, False),
+ (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, YESTERDAY, True),
+ (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, TODAY, False),
+ (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, TODAY, True),
+ (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, TOMORROW, False),
+ (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, TOMORROW, True),
)
@ddt.unpack
def test_progress_page_no_problem_scores(self, show_correctness, due_date_name, graded):
@@ -1821,6 +1830,14 @@ def test_progress_page_no_problem_scores(self, show_correctness, due_date_name,
(ShowCorrectness.PAST_DUE, TODAY, True, True),
(ShowCorrectness.PAST_DUE, TOMORROW, False, False),
(ShowCorrectness.PAST_DUE, TOMORROW, True, False),
+ (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, None, False, False),
+ (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, None, True, False),
+ (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, YESTERDAY, False, False),
+ (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, YESTERDAY, True, False),
+ (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, TODAY, False, False),
+ (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, TODAY, True, False),
+ (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, TOMORROW, False, False),
+ (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, TOMORROW, True, False),
)
@ddt.unpack
def test_progress_page_hide_scores_from_learner(self, show_correctness, due_date_name, graded, show_grades):
@@ -1873,11 +1890,20 @@ def test_progress_page_hide_scores_from_learner(self, show_correctness, due_date
(ShowCorrectness.PAST_DUE, TODAY, True, True),
(ShowCorrectness.PAST_DUE, TOMORROW, False, True),
(ShowCorrectness.PAST_DUE, TOMORROW, True, True),
+ (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, None, False, False),
+ (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, None, True, False),
+ (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, YESTERDAY, False, False),
+ (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, YESTERDAY, True, False),
+ (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, TODAY, False, False),
+ (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, TODAY, True, False),
+ (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, TOMORROW, False, False),
+ (ShowCorrectness.NEVER_BUT_INCLUDE_GRADE, TOMORROW, True, False),
)
@ddt.unpack
def test_progress_page_hide_scores_from_staff(self, show_correctness, due_date_name, graded, show_grades):
"""
- Test that problem scores are hidden from staff viewing a learner's progress page only if show_correctness=never.
+ Test that problem scores are hidden from staff viewing a learner's progress page only if show_correctness is
+ never or never_but_include_grade.
"""
due_date = self.DATES[due_date_name]
self.setup_course(show_correctness=show_correctness, due_date=due_date, graded=graded)
@@ -3421,3 +3447,25 @@ def test_course_about_redirect_to_mfe(self, catalog_mfe_enabled, expected_redire
assert response.url == "http://example.com/catalog/courses/{}/about".format(self.course.id)
else:
assert response.status_code == 200
+
+
+@ddt.ddt
+class UnifiedSiteAndTranslationLanguageEnabledViewTests(TestCase):
+ """
+ Tests for the unified_site_and_translation_language_enabled view
+ """
+ @override_waffle_flag(ENABLE_UNIFIED_SITE_AND_TRANSLATION_LANGUAGE, True)
+ def test_view_logged_out(self):
+ url = reverse('unified_translations_enabled_view')
+ self.client.logout()
+ response = self.client.get(url)
+ assert response.status_code == 302
+
+ @ddt.data(True, False)
+ def test_view(self, enabled):
+ url = reverse('unified_translations_enabled_view')
+ user = UserFactory.create()
+ assert self.client.login(username=user.username, password=TEST_PASSWORD)
+ with override_waffle_flag(ENABLE_UNIFIED_SITE_AND_TRANSLATION_LANGUAGE, enabled):
+ response = self.client.get(url)
+ assert response.json()['enabled'] == enabled
diff --git a/lms/djangoapps/courseware/toggles.py b/lms/djangoapps/courseware/toggles.py
index f9f083cad42e..4d3ba067162d 100644
--- a/lms/djangoapps/courseware/toggles.py
+++ b/lms/djangoapps/courseware/toggles.py
@@ -2,7 +2,7 @@
Toggles for courseware in-course experience.
"""
-from edx_toggles.toggles import SettingToggle, WaffleSwitch
+from edx_toggles.toggles import SettingToggle, WaffleSwitch, WaffleFlag
from openedx.core.djangoapps.waffle_utils import CourseWaffleFlag
@@ -168,6 +168,19 @@
f'{WAFFLE_FLAG_NAMESPACE}.discovery_default_language_filter', __name__
)
+# .. toggle_name: courseware.unify_site_and_translation_language
+# .. toggle_implementation: WaffleFlag
+# .. toggle_default: False
+# .. toggle_description: Update LMS to use site language for xpert unit translations and enable new header site language switcher.
+# .. toggle_use_cases: opt_in
+# .. toggle_creation_date: 2026-01-08
+# .. toggle_target_removal_date: None
+# .. toggle_warning: None.
+# .. toggle_tickets: https://github.com/edx/edx-platform/pull/81
+ENABLE_UNIFIED_SITE_AND_TRANSLATION_LANGUAGE = WaffleFlag(
+ f'{WAFFLE_FLAG_NAMESPACE}.unify_site_and_translation_language', __name__
+)
+
def course_exit_page_is_active(course_key):
return COURSEWARE_MICROFRONTEND_COURSE_EXIT_PAGE.is_enabled(course_key)
@@ -202,3 +215,10 @@ def courseware_disable_navigation_sidebar_blocks_caching(course_key=None):
Return whether the courseware.disable_navigation_sidebar_blocks_caching flag is on.
"""
return COURSEWARE_MICROFRONTEND_NAVIGATION_SIDEBAR_BLOCKS_DISABLE_CACHING.is_enabled(course_key)
+
+
+def unified_site_and_translation_language_is_enabled():
+ """
+ Return whether the courseware.unify_site_and_translation_language flag is on.
+ """
+ return ENABLE_UNIFIED_SITE_AND_TRANSLATION_LANGUAGE.is_enabled()
diff --git a/lms/djangoapps/courseware/views/views.py b/lms/djangoapps/courseware/views/views.py
index 2c89800d82b6..ccee9a0fa729 100644
--- a/lms/djangoapps/courseware/views/views.py
+++ b/lms/djangoapps/courseware/views/views.py
@@ -159,6 +159,7 @@
from ..toggles import (
COURSEWARE_OPTIMIZED_RENDER_XBLOCK,
ENABLE_COURSE_DISCOVERY_DEFAULT_LANGUAGE_FILTER,
+ unified_site_and_translation_language_is_enabled,
)
log = logging.getLogger("edx.courseware")
@@ -2402,3 +2403,12 @@ def courseware_mfe_navigation_sidebar_toggles(request, course_id=None):
# Add completion tracking status for the sidebar use while a global place for switches is put in place
"enable_completion_tracking": ENABLE_COMPLETION_TRACKING_SWITCH.is_enabled()
})
+
+
+@login_required
+@api_view(['GET'])
+def unified_site_and_translation_language_enabled(request):
+ """
+ Simple GET endpoint exposing whether the unified site/translation language feature flag is enabled
+ """
+ return JsonResponse({'enabled': unified_site_and_translation_language_is_enabled()})
diff --git a/lms/djangoapps/discussion/rest_api/api.py b/lms/djangoapps/discussion/rest_api/api.py
index b87852c16cfa..ac131fa4c7e0 100644
--- a/lms/djangoapps/discussion/rest_api/api.py
+++ b/lms/djangoapps/discussion/rest_api/api.py
@@ -1,17 +1,17 @@
"""
Discussion API internal interface
"""
+
from __future__ import annotations
import itertools
+import logging
import re
from collections import defaultdict
from datetime import datetime
-
from enum import Enum
from typing import Dict, Iterable, List, Literal, Optional, Set, Tuple
from urllib.parse import urlencode, urlunparse
-from pytz import UTC
from django.conf import settings
from django.contrib.auth import get_user_model
@@ -19,24 +19,27 @@
from django.db.models import Q
from django.http import Http404
from django.urls import reverse
+from django.utils.html import strip_tags
from edx_django_utils.monitoring import function_trace
from opaque_keys import InvalidKeyError
from opaque_keys.edx.locator import CourseKey
+from pytz import UTC
from rest_framework import status
from rest_framework.exceptions import PermissionDenied
from rest_framework.request import Request
from rest_framework.response import Response
-from common.djangoapps.student.roles import (
- CourseInstructorRole,
- CourseStaffRole,
-)
-
+from common.djangoapps.student.roles import CourseInstructorRole, CourseStaffRole
+from forum import api as forum_api
from lms.djangoapps.course_api.blocks.api import get_blocks
from lms.djangoapps.courseware.courses import get_course_with_access
from lms.djangoapps.courseware.exceptions import CourseAccessRedirect
from lms.djangoapps.discussion.rate_limit import is_content_creation_rate_limited
-from lms.djangoapps.discussion.toggles import ENABLE_DISCUSSIONS_MFE, ONLY_VERIFIED_USERS_CAN_POST
+from lms.djangoapps.discussion.toggles import (
+ ENABLE_DISCUSSIONS_MFE,
+ ENABLE_DISCUSSION_BAN,
+ ONLY_VERIFIED_USERS_CAN_POST,
+)
from lms.djangoapps.discussion.views import is_privileged_user
from openedx.core.djangoapps.discussions.models import (
DiscussionsConfiguration,
@@ -48,12 +51,12 @@
from openedx.core.djangoapps.django_comment_common.comment_client.comment import Comment
from openedx.core.djangoapps.django_comment_common.comment_client.course import (
get_course_commentable_counts,
- get_course_user_stats
+ get_course_user_stats,
)
from openedx.core.djangoapps.django_comment_common.comment_client.thread import Thread
from openedx.core.djangoapps.django_comment_common.comment_client.utils import (
CommentClient500Error,
- CommentClientRequestError
+ CommentClientRequestError,
)
from openedx.core.djangoapps.django_comment_common.models import (
FORUM_ROLE_ADMINISTRATOR,
@@ -61,13 +64,13 @@
FORUM_ROLE_GROUP_MODERATOR,
FORUM_ROLE_MODERATOR,
CourseDiscussionSettings,
- Role
+ Role,
)
from openedx.core.djangoapps.django_comment_common.signals import (
comment_created,
comment_deleted,
- comment_endorsed,
comment_edited,
+ comment_endorsed,
comment_flagged,
comment_voted,
thread_created,
@@ -75,11 +78,15 @@
thread_edited,
thread_flagged,
thread_followed,
+ thread_unfollowed,
thread_voted,
- thread_unfollowed
)
from openedx.core.djangoapps.user_api.accounts.api import get_account_settings
-from openedx.core.lib.exceptions import CourseNotFoundError, DiscussionNotFoundError, PageNotFoundError
+from openedx.core.lib.exceptions import (
+ CourseNotFoundError,
+ DiscussionNotFoundError,
+ PageNotFoundError,
+)
from xmodule.course_block import CourseBlock
from xmodule.modulestore import ModuleStoreEnum
from xmodule.modulestore.django import modulestore
@@ -88,21 +95,27 @@
from ..django_comment_client.base.views import (
track_comment_created_event,
track_comment_deleted_event,
+ track_discussion_reported_event,
+ track_discussion_unreported_event,
+ track_forum_search_event,
track_thread_created_event,
track_thread_deleted_event,
+ track_thread_followed_event,
track_thread_viewed_event,
track_voted_event,
- track_discussion_reported_event,
- track_discussion_unreported_event,
- track_forum_search_event, track_thread_followed_event
)
from ..django_comment_client.utils import (
get_group_id_for_user,
get_user_role_names,
has_discussion_privileges,
- is_commentable_divided
+ is_commentable_divided,
+)
+from .exceptions import (
+ CommentNotFoundError,
+ DiscussionBlackOutException,
+ DiscussionDisabledError,
+ ThreadNotFoundError,
)
-from .exceptions import CommentNotFoundError, DiscussionBlackOutException, DiscussionDisabledError, ThreadNotFoundError
from .forms import CommentActionsForm, ThreadActionsForm, UserOrdering
from .pagination import DiscussionAPIPagination
from .permissions import (
@@ -110,7 +123,7 @@
can_take_action_on_spam,
get_editable_fields,
get_initializable_comment_fields,
- get_initializable_thread_fields
+ get_initializable_thread_fields,
)
from .serializers import (
CommentSerializer,
@@ -119,20 +132,23 @@
ThreadSerializer,
TopicOrdering,
UserStatsSerializer,
- get_context
+ get_context,
)
from .utils import (
AttributeDict,
add_stats_for_users_with_no_discussion_content,
+ can_user_notify_all_learners,
create_blocks_params,
discussion_open_for_user,
+ get_captcha_site_key_by_platform,
get_usernames_for_course,
get_usernames_from_search_string,
- set_attribute,
+ is_captcha_enabled,
is_posting_allowed,
- can_user_notify_all_learners, is_captcha_enabled, get_captcha_site_key_by_platform
+ set_attribute,
)
+log = logging.getLogger(__name__)
User = get_user_model()
ThreadType = Literal["discussion", "question"]
@@ -166,11 +182,14 @@ class DiscussionEntity(Enum):
"""
Enum for different types of discussion related entities
"""
- thread = 'thread'
- comment = 'comment'
+
+ thread = "thread"
+ comment = "comment"
-def _get_course(course_key: CourseKey, user: User, check_tab: bool = True) -> CourseBlock:
+def _get_course(
+ course_key: CourseKey, user: User, check_tab: bool = True
+) -> CourseBlock:
"""
Get the course block, raising CourseNotFoundError if the course is not found or
the user cannot access forums for the course, and DiscussionDisabledError if the
@@ -188,14 +207,16 @@ def _get_course(course_key: CourseKey, user: User, check_tab: bool = True) -> Co
CourseBlock: course object
"""
try:
- course = get_course_with_access(user, 'load', course_key, check_if_enrolled=True)
+ course = get_course_with_access(
+ user, "load", course_key, check_if_enrolled=True
+ )
except (Http404, CourseAccessRedirect) as err:
# Convert 404s into CourseNotFoundErrors.
# Raise course not found if the user cannot access the course
raise CourseNotFoundError("Course not found.") from err
if check_tab:
- discussion_tab = CourseTabList.get_tab_by_type(course.tabs, 'discussion')
+ discussion_tab = CourseTabList.get_tab_by_type(course.tabs, "discussion")
if not (discussion_tab and discussion_tab.is_enabled(course, user)):
raise DiscussionDisabledError("Discussion is disabled for the course.")
@@ -216,22 +237,34 @@ def _get_thread_and_context(request, thread_id, retrieve_kwargs=None, course_id=
retrieve_kwargs["with_responses"] = False
if "mark_as_read" not in retrieve_kwargs:
retrieve_kwargs["mark_as_read"] = False
- cc_thread = Thread(id=thread_id).retrieve(course_id=course_id, **retrieve_kwargs)
+ cc_thread = Thread(id=thread_id).retrieve(
+ course_id=course_id, **retrieve_kwargs
+ )
course_key = CourseKey.from_string(cc_thread["course_id"])
course = _get_course(course_key, request.user)
context = get_context(course, request, cc_thread)
- if retrieve_kwargs.get("flagged_comments") and not context["has_moderation_privilege"]:
+ if (
+ retrieve_kwargs.get("flagged_comments")
+ and not context["has_moderation_privilege"]
+ ):
raise ValidationError("Only privileged users can request flagged comments")
course_discussion_settings = CourseDiscussionSettings.get(course_key)
if (
- not context["has_moderation_privilege"] and
- cc_thread["group_id"] and
- is_commentable_divided(course.id, cc_thread["commentable_id"], course_discussion_settings)
+ not context["has_moderation_privilege"]
+ and cc_thread["group_id"]
+ and is_commentable_divided(
+ course.id, cc_thread["commentable_id"], course_discussion_settings
+ )
):
- requester_group_id = get_group_id_for_user(request.user, course_discussion_settings)
- if requester_group_id is not None and cc_thread["group_id"] != requester_group_id:
+ requester_group_id = get_group_id_for_user(
+ request.user, course_discussion_settings
+ )
+ if (
+ requester_group_id is not None
+ and cc_thread["group_id"] != requester_group_id
+ ):
raise ThreadNotFoundError("Thread not found.")
return cc_thread, context
except CommentClientRequestError as err:
@@ -264,8 +297,8 @@ def _is_user_author_or_privileged(cc_content, context):
Boolean
"""
return (
- context["has_moderation_privilege"] or
- context["cc_requester"]["id"] == cc_content["user_id"]
+ context["has_moderation_privilege"]
+ or context["cc_requester"]["id"] == cc_content["user_id"]
)
@@ -275,11 +308,13 @@ def get_thread_list_url(request, course_key, topic_id_list=None, following=False
"""
path = reverse("thread-list")
query_list = (
- [("course_id", str(course_key))] +
- [("topic_id", topic_id) for topic_id in topic_id_list or []] +
- ([("following", following)] if following else [])
+ [("course_id", str(course_key))]
+ + [("topic_id", topic_id) for topic_id in topic_id_list or []]
+ + ([("following", following)] if following else [])
+ )
+ return request.build_absolute_uri(
+ urlunparse(("", "", path, "", urlencode(query_list), ""))
)
- return request.build_absolute_uri(urlunparse(("", "", path, "", urlencode(query_list), "")))
def get_course(request, course_key, check_tab=True):
@@ -324,23 +359,37 @@ def _format_datetime(dt):
the substitution... though really, that would probably break mobile
client parsing of the dates as well. :-P
"""
- return dt.isoformat().replace('+00:00', 'Z')
+ return dt.isoformat().replace("+00:00", "Z")
course = _get_course(course_key, request.user, check_tab=check_tab)
user_roles = get_user_role_names(request.user, course_key)
course_config = DiscussionsConfiguration.get(course_key)
EDIT_REASON_CODES = getattr(settings, "DISCUSSION_MODERATION_EDIT_REASON_CODES", {})
- CLOSE_REASON_CODES = getattr(settings, "DISCUSSION_MODERATION_CLOSE_REASON_CODES", {})
+ CLOSE_REASON_CODES = getattr(
+ settings, "DISCUSSION_MODERATION_CLOSE_REASON_CODES", {}
+ )
is_posting_enabled = is_posting_allowed(
- course_config.posting_restrictions,
- course.get_discussion_blackout_datetimes()
+ course_config.posting_restrictions, course.get_discussion_blackout_datetimes()
)
- discussion_tab = CourseTabList.get_tab_by_type(course.tabs, 'discussion')
+ discussion_tab = CourseTabList.get_tab_by_type(course.tabs, "discussion")
is_course_staff = CourseStaffRole(course_key).has_user(request.user)
is_course_admin = CourseInstructorRole(course_key).has_user(request.user)
+
+ # Check if the user is banned from discussions
+ is_user_banned_func = getattr(forum_api, 'is_user_banned', None)
+ is_user_banned = False
+ # Only check ban status if feature flag is enabled
+ if ENABLE_DISCUSSION_BAN.is_enabled(course_key) and is_user_banned_func is not None:
+ try:
+ is_user_banned = is_user_banned_func(request.user, course_key)
+ except Exception: # pylint: disable=broad-except
+ # If ban check fails, default to False
+ is_user_banned = False
+
return {
"id": str(course_key),
"is_posting_enabled": is_posting_enabled,
+ "is_user_banned": is_user_banned,
"blackouts": [
{
"start": _format_datetime(blackout["start"]),
@@ -349,7 +398,9 @@ def _format_datetime(dt):
for blackout in course.get_discussion_blackout_datetimes()
],
"thread_list_url": get_thread_list_url(request, course_key),
- "following_thread_list_url": get_thread_list_url(request, course_key, following=True),
+ "following_thread_list_url": get_thread_list_url(
+ request, course_key, following=True
+ ),
"topics_url": request.build_absolute_uri(
reverse("course_topics", kwargs={"course_id": course_key})
),
@@ -357,18 +408,23 @@ def _format_datetime(dt):
"allow_anonymous_to_peers": course.allow_anonymous_to_peers,
"user_roles": user_roles,
"has_bulk_delete_privileges": can_take_action_on_spam(request.user, course_key),
- "has_moderation_privileges": bool(user_roles & {
- FORUM_ROLE_ADMINISTRATOR,
- FORUM_ROLE_MODERATOR,
- FORUM_ROLE_COMMUNITY_TA,
- }),
+ "has_moderation_privileges": bool(
+ user_roles
+ & {
+ FORUM_ROLE_ADMINISTRATOR,
+ FORUM_ROLE_MODERATOR,
+ FORUM_ROLE_COMMUNITY_TA,
+ }
+ ),
"is_group_ta": bool(user_roles & {FORUM_ROLE_GROUP_MODERATOR}),
"is_user_admin": request.user.is_staff,
"is_course_staff": is_course_staff,
"is_course_admin": is_course_admin,
"provider": course_config.provider_type,
"enable_in_context": course_config.enable_in_context,
- "group_at_subsection": course_config.plugin_configuration.get("group_at_subsection", False),
+ "group_at_subsection": course_config.plugin_configuration.get(
+ "group_at_subsection", False
+ ),
"edit_reasons": [
{"code": reason_code, "label": label}
for (reason_code, label) in EDIT_REASON_CODES.items()
@@ -377,17 +433,24 @@ def _format_datetime(dt):
{"code": reason_code, "label": label}
for (reason_code, label) in CLOSE_REASON_CODES.items()
],
- 'show_discussions': bool(discussion_tab and discussion_tab.is_enabled(course, request.user)),
- 'is_notify_all_learners_enabled': can_user_notify_all_learners(
+ "show_discussions": bool(
+ discussion_tab and discussion_tab.is_enabled(course, request.user)
+ ),
+ "is_notify_all_learners_enabled": can_user_notify_all_learners(
user_roles, is_course_staff, is_course_admin
),
- 'captcha_settings': {
- 'enabled': is_captcha_enabled(course_key),
- 'site_key': get_captcha_site_key_by_platform('web'),
+ "captcha_settings": {
+ "enabled": is_captcha_enabled(course_key),
+ "site_key": get_captcha_site_key_by_platform("web"),
},
"is_email_verified": request.user.is_active,
- "only_verified_users_can_post": ONLY_VERIFIED_USERS_CAN_POST.is_enabled(course_key),
- "content_creation_rate_limited": is_content_creation_rate_limited(request, course_key, increment=False),
+ "only_verified_users_can_post": ONLY_VERIFIED_USERS_CAN_POST.is_enabled(
+ course_key
+ ),
+ "content_creation_rate_limited": is_content_creation_rate_limited(
+ request, course_key, increment=False
+ ),
+ "enable_discussion_ban": ENABLE_DISCUSSION_BAN.is_enabled(course_key),
}
@@ -440,7 +503,7 @@ def convert(text):
return text
def alphanum_key(key):
- return [convert(c) for c in re.split('([0-9]+)', key)]
+ return [convert(c) for c in re.split("([0-9]+)", key)]
return sorted(category_list, key=alphanum_key)
@@ -482,7 +545,7 @@ def get_non_courseware_topics(
course_key: CourseKey,
course: CourseBlock,
topic_ids: Optional[List[str]],
- thread_counts: Dict[str, Dict[str, int]]
+ thread_counts: Dict[str, Dict[str, int]],
) -> Tuple[List[Dict], Set[str]]:
"""
Returns a list of topic trees that are not linked to courseware.
@@ -506,13 +569,17 @@ def get_non_courseware_topics(
existing_topic_ids = set()
topics = list(course.discussion_topics.items())
for name, entry in topics:
- if not topic_ids or entry['id'] in topic_ids:
+ if not topic_ids or entry["id"] in topic_ids:
discussion_topic = DiscussionTopic(
- entry["id"], name, get_thread_list_url(request, course_key, [entry["id"]]),
+ entry["id"],
+ name,
+ get_thread_list_url(request, course_key, [entry["id"]]),
None,
- thread_counts.get(entry["id"])
+ thread_counts.get(entry["id"]),
+ )
+ non_courseware_topics.append(
+ DiscussionTopicSerializer(discussion_topic).data
)
- non_courseware_topics.append(DiscussionTopicSerializer(discussion_topic).data)
if topic_ids and entry["id"] in topic_ids:
existing_topic_ids.add(entry["id"])
@@ -520,7 +587,9 @@ def get_non_courseware_topics(
return non_courseware_topics, existing_topic_ids
-def get_course_topics(request: Request, course_key: CourseKey, topic_ids: Optional[Set[str]] = None):
+def get_course_topics(
+ request: Request, course_key: CourseKey, topic_ids: Optional[Set[str]] = None
+):
"""
Returns the course topic listing for the given course and user; filtered
by 'topic_ids' list if given.
@@ -544,15 +613,25 @@ def get_course_topics(request: Request, course_key: CourseKey, topic_ids: Option
courseware_topics, existing_courseware_topic_ids = get_courseware_topics(
request, course_key, course, topic_ids, thread_counts
)
- non_courseware_topics, existing_non_courseware_topic_ids = get_non_courseware_topics(
- request, course_key, course, topic_ids, thread_counts,
+ non_courseware_topics, existing_non_courseware_topic_ids = (
+ get_non_courseware_topics(
+ request,
+ course_key,
+ course,
+ topic_ids,
+ thread_counts,
+ )
)
if topic_ids:
- not_found_topic_ids = topic_ids - (existing_courseware_topic_ids | existing_non_courseware_topic_ids)
+ not_found_topic_ids = topic_ids - (
+ existing_courseware_topic_ids | existing_non_courseware_topic_ids
+ )
if not_found_topic_ids:
raise DiscussionNotFoundError(
- "Discussion not found for '{}'.".format(", ".join(str(id) for id in not_found_topic_ids))
+ "Discussion not found for '{}'.".format(
+ ", ".join(str(id) for id in not_found_topic_ids)
+ )
)
return {
@@ -567,17 +646,19 @@ def get_v2_non_courseware_topics_as_v1(request, course_key, topics):
"""
non_courseware_topics = []
for topic in topics:
- if topic.get('usage_key', '') is None:
- for key in ['usage_key', 'enabled_in_context']:
+ if topic.get("usage_key", "") is None:
+ for key in ["usage_key", "enabled_in_context"]:
topic.pop(key)
- topic.update({
- 'children': [],
- 'thread_list_url': get_thread_list_url(
- request,
- course_key,
- topic.get('id'),
- )
- })
+ topic.update(
+ {
+ "children": [],
+ "thread_list_url": get_thread_list_url(
+ request,
+ course_key,
+ topic.get("id"),
+ ),
+ }
+ )
non_courseware_topics.append(topic)
return non_courseware_topics
@@ -589,23 +670,25 @@ def get_v2_courseware_topics_as_v1(request, course_key, sequentials, topics):
courseware_topics = []
for sequential in sequentials:
children = []
- for child in sequential.get('children', []):
+ for child in sequential.get("children", []):
for topic in topics:
- if child == topic.get('usage_key'):
- topic.update({
- 'children': [],
- 'thread_list_url': get_thread_list_url(
- request,
- course_key,
- [topic.get('id')],
- )
- })
- topic.pop('enabled_in_context')
+ if child == topic.get("usage_key"):
+ topic.update(
+ {
+ "children": [],
+ "thread_list_url": get_thread_list_url(
+ request,
+ course_key,
+ [topic.get("id")],
+ ),
+ }
+ )
+ topic.pop("enabled_in_context")
children.append(AttributeDict(topic))
discussion_topic = DiscussionTopic(
None,
- sequential.get('display_name'),
+ sequential.get("display_name"),
get_thread_list_url(
request,
course_key,
@@ -618,7 +701,7 @@ def get_v2_courseware_topics_as_v1(request, course_key, sequentials, topics):
courseware_topics = [
courseware_topic
for courseware_topic in courseware_topics
- if courseware_topic.get('children', [])
+ if courseware_topic.get("children", [])
]
return courseware_topics
@@ -635,20 +718,21 @@ def get_v2_course_topics_as_v1(
blocks_params = create_blocks_params(course_usage_key, request.user)
blocks = get_blocks(
request,
- blocks_params['usage_key'],
- blocks_params['user'],
- blocks_params['depth'],
- blocks_params['nav_depth'],
- blocks_params['requested_fields'],
- blocks_params['block_counts'],
- blocks_params['student_view_data'],
- blocks_params['return_type'],
- blocks_params['block_types_filter'],
+ blocks_params["usage_key"],
+ blocks_params["user"],
+ blocks_params["depth"],
+ blocks_params["nav_depth"],
+ blocks_params["requested_fields"],
+ blocks_params["block_counts"],
+ blocks_params["student_view_data"],
+ blocks_params["return_type"],
+ blocks_params["block_types_filter"],
hide_access_denials=False,
- )['blocks']
+ )["blocks"]
- sequentials = [value for _, value in blocks.items()
- if value.get('type') == "sequential"]
+ sequentials = [
+ value for _, value in blocks.items() if value.get("type") == "sequential"
+ ]
topics = get_course_topics_v2(course_key, request.user, topic_ids)
non_courseware_topics = get_v2_non_courseware_topics_as_v1(
@@ -705,24 +789,29 @@ def get_course_topics_v2(
# Check access to the course
store = modulestore()
_get_course(course_key, user=user, check_tab=False)
- user_is_privileged = user.is_staff or user.roles.filter(
- course_id=course_key,
- name__in=[
- FORUM_ROLE_MODERATOR,
- FORUM_ROLE_COMMUNITY_TA,
- FORUM_ROLE_ADMINISTRATOR,
- ]
- ).exists()
+ user_is_privileged = (
+ user.is_staff
+ or user.roles.filter(
+ course_id=course_key,
+ name__in=[
+ FORUM_ROLE_MODERATOR,
+ FORUM_ROLE_COMMUNITY_TA,
+ FORUM_ROLE_ADMINISTRATOR,
+ ],
+ ).exists()
+ )
with store.branch_setting(ModuleStoreEnum.Branch.draft_preferred, course_key):
blocks = store.get_items(
course_key,
- qualifiers={'category': 'vertical'},
- fields=['usage_key', 'discussion_enabled', 'display_name'],
+ qualifiers={"category": "vertical"},
+ fields=["usage_key", "discussion_enabled", "display_name"],
)
accessible_vertical_keys = []
for block in blocks:
- if block.discussion_enabled and (not block.visible_to_staff_only or user_is_privileged):
+ if block.discussion_enabled and (
+ not block.visible_to_staff_only or user_is_privileged
+ ):
accessible_vertical_keys.append(block.usage_key)
accessible_vertical_keys.append(None)
@@ -732,9 +821,13 @@ def get_course_topics_v2(
)
if user_is_privileged:
- topics_query = topics_query.filter(Q(usage_key__in=accessible_vertical_keys) | Q(enabled_in_context=False))
+ topics_query = topics_query.filter(
+ Q(usage_key__in=accessible_vertical_keys) | Q(enabled_in_context=False)
+ )
else:
- topics_query = topics_query.filter(usage_key__in=accessible_vertical_keys, enabled_in_context=True)
+ topics_query = topics_query.filter(
+ usage_key__in=accessible_vertical_keys, enabled_in_context=True
+ )
if topic_ids:
topics_query = topics_query.filter(external_id__in=topic_ids)
@@ -746,11 +839,13 @@ def get_course_topics_v2(
reverse=True,
)
elif order_by == TopicOrdering.NAME:
- topics_query = topics_query.order_by('title')
+ topics_query = topics_query.order_by("title")
else:
- topics_query = topics_query.order_by('ordering')
+ topics_query = topics_query.order_by("ordering")
- topics_data = DiscussionTopicSerializerV2(topics_query, many=True, context={"thread_counts": thread_counts}).data
+ topics_data = DiscussionTopicSerializerV2(
+ topics_query, many=True, context={"thread_counts": thread_counts}
+ ).data
return [
topic_data
for topic_data in topics_data
@@ -777,7 +872,7 @@ def _get_user_profile_dict(request, usernames):
else:
username_list = []
user_profile_details = get_account_settings(request, username_list)
- return {user['username']: user for user in user_profile_details}
+ return {user["username"]: user for user in user_profile_details}
def _user_profile(user_profile):
@@ -785,11 +880,7 @@ def _user_profile(user_profile):
Returns the user profile object. For now, this just comprises the
profile_image details.
"""
- return {
- 'profile': {
- 'image': user_profile['profile_image']
- }
- }
+ return {"profile": {"image": user_profile["profile_image"]}}
def _get_users(discussion_entity_type, discussion_entity, username_profile_dict):
@@ -807,22 +898,28 @@ def _get_users(discussion_entity_type, discussion_entity, username_profile_dict)
A dict of users with username as key and user profile details as value.
"""
users = {}
- if discussion_entity['author']:
- user_profile = username_profile_dict.get(discussion_entity['author'])
+ if discussion_entity["author"]:
+ user_profile = username_profile_dict.get(discussion_entity["author"])
if user_profile:
- users[discussion_entity['author']] = _user_profile(user_profile)
+ users[discussion_entity["author"]] = _user_profile(user_profile)
if (
discussion_entity_type == DiscussionEntity.comment
- and discussion_entity['endorsed']
- and discussion_entity['endorsed_by']
+ and discussion_entity["endorsed"]
+ and discussion_entity["endorsed_by"]
):
- users[discussion_entity['endorsed_by']] = _user_profile(username_profile_dict[discussion_entity['endorsed_by']])
+ users[discussion_entity["endorsed_by"]] = _user_profile(
+ username_profile_dict[discussion_entity["endorsed_by"]]
+ )
return users
def _add_additional_response_fields(
- request, serialized_discussion_entities, usernames, discussion_entity_type, include_profile_image
+ request,
+ serialized_discussion_entities,
+ usernames,
+ discussion_entity_type,
+ include_profile_image,
):
"""
Adds additional data to serialized discussion thread/comment.
@@ -840,9 +937,13 @@ def _add_additional_response_fields(
A list of serialized discussion thread/comment with additional data if requested.
"""
if include_profile_image:
- username_profile_dict = _get_user_profile_dict(request, usernames=','.join(usernames))
+ username_profile_dict = _get_user_profile_dict(
+ request, usernames=",".join(usernames)
+ )
for discussion_entity in serialized_discussion_entities:
- discussion_entity['users'] = _get_users(discussion_entity_type, discussion_entity, username_profile_dict)
+ discussion_entity["users"] = _get_users(
+ discussion_entity_type, discussion_entity, username_profile_dict
+ )
return serialized_discussion_entities
@@ -851,10 +952,12 @@ def _include_profile_image(requested_fields):
"""
Returns True if requested_fields list has 'profile_image' entity else False
"""
- return requested_fields and 'profile_image' in requested_fields
+ return requested_fields and "profile_image" in requested_fields
-def _serialize_discussion_entities(request, context, discussion_entities, requested_fields, discussion_entity_type):
+def _serialize_discussion_entities(
+ request, context, discussion_entities, requested_fields, discussion_entity_type
+):
"""
It serializes Discussion Entity (Thread or Comment) and add additional data if requested.
@@ -885,14 +988,19 @@ def _serialize_discussion_entities(request, context, discussion_entities, reques
results.append(serialized_entity)
if include_profile_image:
- if serialized_entity['author'] and serialized_entity['author'] not in usernames:
- usernames.append(serialized_entity['author'])
if (
- 'endorsed' in serialized_entity and serialized_entity['endorsed'] and
- 'endorsed_by' in serialized_entity and
- serialized_entity['endorsed_by'] and serialized_entity['endorsed_by'] not in usernames
+ serialized_entity["author"]
+ and serialized_entity["author"] not in usernames
+ ):
+ usernames.append(serialized_entity["author"])
+ if (
+ "endorsed" in serialized_entity
+ and serialized_entity["endorsed"]
+ and "endorsed_by" in serialized_entity
+ and serialized_entity["endorsed_by"]
+ and serialized_entity["endorsed_by"] not in usernames
):
- usernames.append(serialized_entity['endorsed_by'])
+ usernames.append(serialized_entity["endorsed_by"])
results = _add_additional_response_fields(
request, results, usernames, discussion_entity_type, include_profile_image
@@ -916,6 +1024,7 @@ def get_thread_list(
order_direction: Literal["desc"] = "desc",
requested_fields: Optional[List[Literal["profile_image"]]] = None,
count_flagged: bool = None,
+ show_deleted: bool = False,
):
"""
Return the list of all discussion threads pertaining to the given course
@@ -959,20 +1068,31 @@ def get_thread_list(
CourseNotFoundError: if the requesting user does not have access to the requested course
PageNotFoundError: if page requested is beyond the last
"""
- exclusive_param_count = sum(1 for param in [topic_id_list, text_search, following] if param)
+ exclusive_param_count = sum(
+ 1 for param in [topic_id_list, text_search, following] if param
+ )
if exclusive_param_count > 1: # pragma: no cover
- raise ValueError("More than one mutually exclusive param passed to get_thread_list")
+ raise ValueError(
+ "More than one mutually exclusive param passed to get_thread_list"
+ )
- cc_map = {"last_activity_at": "activity", "comment_count": "comments", "vote_count": "votes"}
+ cc_map = {
+ "last_activity_at": "activity",
+ "comment_count": "comments",
+ "vote_count": "votes",
+ }
if order_by not in cc_map:
- raise ValidationError({
- "order_by":
- [f"Invalid value. '{order_by}' must be 'last_activity_at', 'comment_count', or 'vote_count'"]
- })
+ raise ValidationError(
+ {
+ "order_by": [
+ f"Invalid value. '{order_by}' must be 'last_activity_at', 'comment_count', or 'vote_count'"
+ ]
+ }
+ )
if order_direction != "desc":
- raise ValidationError({
- "order_direction": [f"Invalid value. '{order_direction}' must be 'desc'"]
- })
+ raise ValidationError(
+ {"order_direction": [f"Invalid value. '{order_direction}' must be 'desc'"]}
+ )
course = _get_course(course_key, request.user)
context = get_context(course, request)
@@ -984,13 +1104,21 @@ def get_thread_list(
except User.DoesNotExist:
# Raising an error for a missing user leaks the presence of a username,
# so just return an empty response.
- return DiscussionAPIPagination(request, 0, 1).get_paginated_response({
- "results": [],
- "text_search_rewrite": None,
- })
+ return DiscussionAPIPagination(request, 0, 1).get_paginated_response(
+ {
+ "results": [],
+ "text_search_rewrite": None,
+ }
+ )
if count_flagged and not context["has_moderation_privilege"]:
- raise PermissionDenied("`count_flagged` can only be set by users with moderator access or higher.")
+ raise PermissionDenied(
+ "`count_flagged` can only be set by users with moderator access or higher."
+ )
+ if show_deleted and not context["has_moderation_privilege"]:
+ raise PermissionDenied(
+ "`show_deleted` can only be set by users with moderator access or higher."
+ )
group_id = None
allowed_roles = [
@@ -1010,7 +1138,9 @@ def get_thread_list(
not context["has_moderation_privilege"]
or request.user.id in context["ta_user_ids"]
):
- group_id = get_group_id_for_user(request.user, CourseDiscussionSettings.get(course.id))
+ group_id = get_group_id_for_user(
+ request.user, CourseDiscussionSettings.get(course.id)
+ )
query_params = {
"user_id": str(request.user.id),
@@ -1023,21 +1153,24 @@ def get_thread_list(
"flagged": flagged,
"thread_type": thread_type,
"count_flagged": count_flagged,
+ "show_deleted": show_deleted,
}
if view:
if view in ["unread", "unanswered", "unresponded"]:
query_params[view] = "true"
else:
- raise ValidationError({
- "view": [f"Invalid value. '{view}' must be 'unread' or 'unanswered'"]
- })
+ raise ValidationError(
+ {"view": [f"Invalid value. '{view}' must be 'unread' or 'unanswered'"]}
+ )
if following:
paginated_results = context["cc_requester"].subscribed_threads(query_params)
else:
query_params["course_id"] = str(course.id)
- query_params["commentable_ids"] = ",".join(topic_id_list) if topic_id_list else None
+ query_params["commentable_ids"] = (
+ ",".join(topic_id_list) if topic_id_list else None
+ )
query_params["text"] = text_search
paginated_results = Thread.search(query_params)
# The comments service returns the last page of results if the requested
@@ -1047,19 +1180,25 @@ def get_thread_list(
raise PageNotFoundError("Page not found (No results on this page).")
results = _serialize_discussion_entities(
- request, context, paginated_results.collection, requested_fields, DiscussionEntity.thread
+ request,
+ context,
+ paginated_results.collection,
+ requested_fields,
+ DiscussionEntity.thread,
)
paginator = DiscussionAPIPagination(
request,
paginated_results.page,
paginated_results.num_pages,
- paginated_results.thread_count
+ paginated_results.thread_count,
+ )
+ return paginator.get_paginated_response(
+ {
+ "results": results,
+ "text_search_rewrite": paginated_results.corrected_text,
+ }
)
- return paginator.get_paginated_response({
- "results": results,
- "text_search_rewrite": paginated_results.corrected_text,
- })
def get_learner_active_thread_list(request, course_key, query_params):
@@ -1154,49 +1293,101 @@ def get_learner_active_thread_list(request, course_key, query_params):
course = _get_course(course_key, request.user)
context = get_context(course, request)
- group_id = query_params.get('group_id', None)
- user_id = query_params.get('user_id', None)
- count_flagged = query_params.get('count_flagged', None)
+ group_id = query_params.get("group_id", None)
+ user_id = query_params.get("user_id", None)
+ count_flagged = query_params.get("count_flagged", None)
+ show_deleted = query_params.get("show_deleted", False)
+ if isinstance(show_deleted, str):
+ show_deleted = show_deleted.lower() == "true"
+
if user_id is None:
- return Response({'detail': 'Invalid user id'}, status=status.HTTP_400_BAD_REQUEST)
+ return Response(
+ {"detail": "Invalid user id"}, status=status.HTTP_400_BAD_REQUEST
+ )
if count_flagged and not context["has_moderation_privilege"]:
- raise PermissionDenied("count_flagged can only be set by users with moderation roles.")
+ raise PermissionDenied(
+ "count_flagged can only be set by users with moderation roles."
+ )
if "flagged" in query_params.keys() and not context["has_moderation_privilege"]:
raise PermissionDenied("Flagged filter is only available for moderators")
+ if show_deleted and not context["has_moderation_privilege"]:
+ raise PermissionDenied(
+ "show_deleted can only be set by users with moderation roles."
+ )
if group_id is None:
comment_client_user = comment_client.User(id=user_id, course_id=course_key)
else:
- comment_client_user = comment_client.User(id=user_id, course_id=course_key, group_id=group_id)
+ comment_client_user = comment_client.User(
+ id=user_id, course_id=course_key, group_id=group_id
+ )
try:
threads, page, num_pages = comment_client_user.active_threads(query_params)
threads = set_attribute(threads, "pinned", False)
+
+ # This portion below is temporary until we migrate to forum v2
+ filtered_threads = []
+ for thread in threads:
+ try:
+ forum_thread = forum_api.get_thread(
+ thread.get("id"), course_id=str(course_key)
+ )
+ is_deleted = forum_thread.get("is_deleted", False)
+
+ if show_deleted and is_deleted:
+ thread["is_deleted"] = True
+ thread["deleted_at"] = forum_thread.get("deleted_at")
+ thread["deleted_by"] = forum_thread.get("deleted_by")
+ filtered_threads.append(thread)
+ elif not show_deleted and not is_deleted:
+ filtered_threads.append(thread)
+ except Exception as e: # pylint: disable=broad-exception-caught
+ log.warning(
+ "Failed to check thread %s deletion status: %s", thread.get("id"), e
+ )
+ if not show_deleted: # Fail safe: include thread for regular users
+ filtered_threads.append(thread)
+
results = _serialize_discussion_entities(
- request, context, threads, {'profile_image'}, DiscussionEntity.thread
+ request,
+ context,
+ filtered_threads,
+ {"profile_image"},
+ DiscussionEntity.thread,
)
paginator = DiscussionAPIPagination(
- request,
- page,
- num_pages,
- len(threads)
+ request, page, num_pages, len(filtered_threads)
+ )
+ return paginator.get_paginated_response(
+ {
+ "results": results,
+ }
)
- return paginator.get_paginated_response({
- "results": results,
- })
except CommentClient500Error:
return DiscussionAPIPagination(
request,
page_num=1,
num_pages=0,
- ).get_paginated_response({
- "results": [],
- })
+ ).get_paginated_response(
+ {
+ "results": [],
+ }
+ )
-def get_comment_list(request, thread_id, endorsed, page, page_size, flagged=False, requested_fields=None,
- merge_question_type_responses=False):
+def get_comment_list(
+ request,
+ thread_id,
+ endorsed,
+ page,
+ page_size,
+ flagged=False,
+ requested_fields=None,
+ merge_question_type_responses=False,
+ show_deleted=False,
+):
"""
Return the list of comments in the given thread.
@@ -1226,7 +1417,7 @@ def get_comment_list(request, thread_id, endorsed, page, page_size, flagged=Fals
discussion.rest_api.views.CommentViewSet for more detail.
"""
response_skip = page_size * (page - 1)
- reverse_order = request.GET.get('reverse_order', False)
+ reverse_order = request.GET.get("reverse_order", False)
from_mfe_sidebar = request.GET.get("enable_in_context_sidebar", False)
cc_thread, context = _get_thread_and_context(
request,
@@ -1239,19 +1430,23 @@ def get_comment_list(request, thread_id, endorsed, page, page_size, flagged=Fals
"response_skip": response_skip,
"response_limit": page_size,
"reverse_order": reverse_order,
- "merge_question_type_responses": merge_question_type_responses
- }
+ "merge_question_type_responses": merge_question_type_responses,
+ },
)
# Responses to discussion threads cannot be separated by endorsed, but
# responses to question threads must be separated by endorsed due to the
# existing comments service interface
if cc_thread["thread_type"] == "question" and not merge_question_type_responses:
if endorsed is None: # lint-amnesty, pylint: disable=no-else-raise
- raise ValidationError({"endorsed": ["This field is required for question threads."]})
+ raise ValidationError(
+ {"endorsed": ["This field is required for question threads."]}
+ )
elif endorsed:
# CS does not apply resp_skip and resp_limit to endorsed responses
# of a question post
- responses = cc_thread["endorsed_responses"][response_skip:(response_skip + page_size)]
+ responses = cc_thread["endorsed_responses"][
+ response_skip: (response_skip + page_size)
+ ]
resp_total = len(cc_thread["endorsed_responses"])
else:
responses = cc_thread["non_endorsed_responses"]
@@ -1260,7 +1455,11 @@ def get_comment_list(request, thread_id, endorsed, page, page_size, flagged=Fals
if not merge_question_type_responses:
if endorsed is not None:
raise ValidationError(
- {"endorsed": ["This field may not be specified for discussion threads."]}
+ {
+ "endorsed": [
+ "This field may not be specified for discussion threads."
+ ]
+ }
)
responses = cc_thread["children"]
resp_total = cc_thread["resp_total"]
@@ -1272,9 +1471,21 @@ def get_comment_list(request, thread_id, endorsed, page, page_size, flagged=Fals
raise PageNotFoundError("Page not found (No results on this page).")
num_pages = (resp_total + page_size - 1) // page_size if resp_total else 1
- results = _serialize_discussion_entities(request, context, responses, requested_fields, DiscussionEntity.comment)
+ if not show_deleted:
+ responses = [
+ response for response in responses if not response.get("is_deleted", False)
+ ]
+ else:
+ if not context["has_moderation_privilege"]:
+ raise PermissionDenied(
+ "`show_deleted` can only be set by users with moderation roles."
+ )
+
+ results = _serialize_discussion_entities(
+ request, context, responses, requested_fields, DiscussionEntity.comment
+ )
- paginator = DiscussionAPIPagination(request, page, num_pages, resp_total)
+ paginator = DiscussionAPIPagination(request, page, num_pages, len(responses))
track_thread_viewed_event(request, context["course"], cc_thread, from_mfe_sidebar)
return paginator.get_paginated_response(results)
@@ -1292,7 +1503,9 @@ def _check_fields(allowed_fields, data, message):
ValidationError if the given data contains a key that is not in
allowed_fields
"""
- non_allowed_fields = {field: [message] for field in data.keys() if field not in allowed_fields}
+ non_allowed_fields = {
+ field: [message] for field in data.keys() if field not in allowed_fields
+ }
if non_allowed_fields:
raise ValidationError(non_allowed_fields)
@@ -1314,7 +1527,7 @@ def _check_initializable_thread_fields(data, context):
_check_fields(
get_initializable_thread_fields(context),
data,
- "This field is not initializable."
+ "This field is not initializable.",
)
@@ -1335,7 +1548,7 @@ def _check_initializable_comment_fields(data, context):
_check_fields(
get_initializable_comment_fields(context),
data,
- "This field is not initializable."
+ "This field is not initializable.",
)
@@ -1345,28 +1558,40 @@ def _check_editable_fields(cc_content, data, context):
editable by the requesting user
"""
_check_fields(
- get_editable_fields(cc_content, context),
- data,
- "This field is not editable."
+ get_editable_fields(cc_content, context), data, "This field is not editable."
)
-def _do_extra_actions(api_content, cc_content, request_fields, actions_form, context, request):
+def _do_extra_actions(
+ api_content, cc_content, request_fields, actions_form, context, request
+):
"""
Perform any necessary additional actions related to content creation or
update that require a separate comments service request.
"""
for field, form_value in actions_form.cleaned_data.items():
- if field in request_fields and field in api_content and form_value != api_content[field]:
+ if (
+ field in request_fields
+ and field in api_content
+ and form_value != api_content[field]
+ ):
api_content[field] = form_value
if field == "following":
- _handle_following_field(form_value, context["cc_requester"], cc_content, request)
+ _handle_following_field(
+ form_value, context["cc_requester"], cc_content, request
+ )
elif field == "abuse_flagged":
- _handle_abuse_flagged_field(form_value, context["cc_requester"], cc_content, request)
+ _handle_abuse_flagged_field(
+ form_value, context["cc_requester"], cc_content, request
+ )
elif field == "voted":
- _handle_voted_field(form_value, cc_content, api_content, request, context)
+ _handle_voted_field(
+ form_value, cc_content, api_content, request, context
+ )
elif field == "read":
- _handle_read_field(api_content, form_value, context["cc_requester"], cc_content)
+ _handle_read_field(
+ api_content, form_value, context["cc_requester"], cc_content
+ )
elif field == "pinned":
_handle_pinned_field(form_value, cc_content, context["cc_requester"])
else:
@@ -1376,7 +1601,7 @@ def _do_extra_actions(api_content, cc_content, request_fields, actions_form, con
def _handle_following_field(form_value, user, cc_content, request):
"""follow/unfollow thread for the user"""
course_key = CourseKey.from_string(cc_content.course_id)
- course = get_course_with_access(request.user, 'load', course_key)
+ course = get_course_with_access(request.user, "load", course_key)
if form_value:
user.follow(cc_content)
else:
@@ -1389,15 +1614,19 @@ def _handle_following_field(form_value, user, cc_content, request):
def _handle_abuse_flagged_field(form_value, user, cc_content, request):
"""mark or unmark thread/comment as abused"""
course_key = CourseKey.from_string(cc_content.course_id)
- course = get_course_with_access(request.user, 'load', course_key)
+ course = get_course_with_access(request.user, "load", course_key)
if form_value:
cc_content.flagAbuse(user, cc_content)
track_discussion_reported_event(request, course, cc_content)
if ENABLE_DISCUSSIONS_MFE.is_enabled(course_key):
- if cc_content.type == 'thread':
- thread_flagged.send(sender='flag_abuse_for_thread', user=user, post=cc_content)
+ if cc_content.type == "thread":
+ thread_flagged.send(
+ sender="flag_abuse_for_thread", user=user, post=cc_content
+ )
else:
- comment_flagged.send(sender='flag_abuse_for_comment', user=user, post=cc_content)
+ comment_flagged.send(
+ sender="flag_abuse_for_comment", user=user, post=cc_content
+ )
else:
remove_all = bool(is_privileged_user(course_key, User.objects.get(id=user.id)))
cc_content.unFlagAbuse(user, cc_content, remove_all)
@@ -1406,7 +1635,7 @@ def _handle_abuse_flagged_field(form_value, user, cc_content, request):
def _handle_voted_field(form_value, cc_content, api_content, request, context):
"""vote or undo vote on thread/comment"""
- signal = thread_voted if cc_content.type == 'thread' else comment_voted
+ signal = thread_voted if cc_content.type == "thread" else comment_voted
signal.send(sender=None, user=context["request"].user, post=cc_content)
if form_value:
context["cc_requester"].vote(cc_content, "up")
@@ -1415,7 +1644,11 @@ def _handle_voted_field(form_value, cc_content, api_content, request, context):
context["cc_requester"].unvote(cc_content)
api_content["vote_count"] -= 1
track_voted_event(
- request, context["course"], cc_content, vote_value="up", undo_vote=not form_value
+ request,
+ context["course"],
+ cc_content,
+ vote_value="up",
+ undo_vote=not form_value,
)
@@ -1423,7 +1656,7 @@ def _handle_read_field(api_content, form_value, user, cc_content):
"""
Marks thread as read for the user
"""
- if form_value and not cc_content['read']:
+ if form_value and not cc_content["read"]:
user.read(cc_content)
# When a thread is marked as read, all of its responses and comments
# are also marked as read.
@@ -1485,29 +1718,56 @@ def create_thread(request, thread_data):
if not discussion_open_for_user(course, user):
raise DiscussionBlackOutException
+ # Check if user is banned from discussions
+ is_user_banned_func = getattr(forum_api, 'is_user_banned', None)
+ user_banned = False
+ if ENABLE_DISCUSSION_BAN.is_enabled(course_key) and is_user_banned_func:
+ try:
+ user_banned = is_user_banned_func(user, course_key)
+ except (CommentClientRequestError, CommentClient500Error) as exc:
+ log.warning(
+ "Error while checking discussion ban status for user %s in course %s: %s",
+ getattr(user, "id", None),
+ course_key,
+ exc,
+ )
+ if user_banned:
+ raise PermissionDenied("You are banned from posting in this course's discussions.")
+
notify_all_learners = thread_data.pop("notify_all_learners", False)
context = get_context(course, request)
_check_initializable_thread_fields(thread_data, context)
discussion_settings = CourseDiscussionSettings.get(course_key)
- if (
- "group_id" not in thread_data and
- is_commentable_divided(course_key, thread_data.get("topic_id"), discussion_settings)
+ if "group_id" not in thread_data and is_commentable_divided(
+ course_key, thread_data.get("topic_id"), discussion_settings
):
thread_data = thread_data.copy()
thread_data["group_id"] = get_group_id_for_user(user, discussion_settings)
serializer = ThreadSerializer(data=thread_data, context=context)
actions_form = ThreadActionsForm(thread_data)
if not (serializer.is_valid() and actions_form.is_valid()):
- raise ValidationError(dict(list(serializer.errors.items()) + list(actions_form.errors.items())))
+ raise ValidationError(
+ dict(list(serializer.errors.items()) + list(actions_form.errors.items()))
+ )
serializer.save()
cc_thread = serializer.instance
- thread_created.send(sender=None, user=user, post=cc_thread, notify_all_learners=notify_all_learners)
+ thread_created.send(
+ sender=None, user=user, post=cc_thread, notify_all_learners=notify_all_learners
+ )
api_thread = serializer.data
- _do_extra_actions(api_thread, cc_thread, list(thread_data.keys()), actions_form, context, request)
+ _do_extra_actions(
+ api_thread, cc_thread, list(thread_data.keys()), actions_form, context, request
+ )
- track_thread_created_event(request, course, cc_thread, actions_form.cleaned_data["following"],
- from_mfe_sidebar, notify_all_learners)
+ track_thread_created_event(
+ request,
+ course,
+ cc_thread,
+ actions_form.cleaned_data["following"],
+ from_mfe_sidebar,
+ notify_all_learners,
+ )
return api_thread
@@ -1538,6 +1798,22 @@ def create_comment(request, comment_data):
if not discussion_open_for_user(course, request.user):
raise DiscussionBlackOutException
+ # Check if user is banned from discussions
+ is_user_banned_func = getattr(forum_api, 'is_user_banned', None)
+ user_banned = False
+ if ENABLE_DISCUSSION_BAN.is_enabled(course.id) and is_user_banned_func:
+ try:
+ user_banned = is_user_banned_func(request.user, course.id)
+ except (CommentClientRequestError, CommentClient500Error) as exc:
+ log.warning(
+ "Error while checking discussion ban status for user %s in course %s: %s",
+ getattr(request.user, "id", None),
+ course.id,
+ exc,
+ )
+ if user_banned:
+ raise PermissionDenied("You are banned from posting in this course's discussions.")
+
# if a thread is closed; no new comments could be made to it
if cc_thread["closed"]:
raise PermissionDenied
@@ -1546,15 +1822,30 @@ def create_comment(request, comment_data):
serializer = CommentSerializer(data=comment_data, context=context)
actions_form = CommentActionsForm(comment_data)
if not (serializer.is_valid() and actions_form.is_valid()):
- raise ValidationError(dict(list(serializer.errors.items()) + list(actions_form.errors.items())))
+ raise ValidationError(
+ dict(list(serializer.errors.items()) + list(actions_form.errors.items()))
+ )
context["cc_requester"].follow(cc_thread)
serializer.save()
cc_comment = serializer.instance
comment_created.send(sender=None, user=request.user, post=cc_comment)
api_comment = serializer.data
- _do_extra_actions(api_comment, cc_comment, list(comment_data.keys()), actions_form, context, request)
- track_comment_created_event(request, course, cc_comment, cc_thread["commentable_id"], followed=False,
- from_mfe_sidebar=from_mfe_sidebar)
+ _do_extra_actions(
+ api_comment,
+ cc_comment,
+ list(comment_data.keys()),
+ actions_form,
+ context,
+ request,
+ )
+ track_comment_created_event(
+ request,
+ course,
+ cc_comment,
+ cc_thread["commentable_id"],
+ followed=False,
+ from_mfe_sidebar=from_mfe_sidebar,
+ )
return api_comment
@@ -1576,24 +1867,32 @@ def update_thread(request, thread_id, update_data):
The updated thread; see discussion.rest_api.views.ThreadViewSet for more
detail.
"""
- cc_thread, context = _get_thread_and_context(request, thread_id, retrieve_kwargs={"with_responses": True})
+ cc_thread, context = _get_thread_and_context(
+ request, thread_id, retrieve_kwargs={"with_responses": True}
+ )
_check_editable_fields(cc_thread, update_data, context)
- serializer = ThreadSerializer(cc_thread, data=update_data, partial=True, context=context)
+ serializer = ThreadSerializer(
+ cc_thread, data=update_data, partial=True, context=context
+ )
actions_form = ThreadActionsForm(update_data)
if not (serializer.is_valid() and actions_form.is_valid()):
- raise ValidationError(dict(list(serializer.errors.items()) + list(actions_form.errors.items())))
+ raise ValidationError(
+ dict(list(serializer.errors.items()) + list(actions_form.errors.items()))
+ )
# Only save thread object if some of the edited fields are in the thread data, not extra actions
if set(update_data) - set(actions_form.fields):
serializer.save()
# signal to update Teams when a user edits a thread
thread_edited.send(sender=None, user=request.user, post=cc_thread)
api_thread = serializer.data
- _do_extra_actions(api_thread, cc_thread, list(update_data.keys()), actions_form, context, request)
+ _do_extra_actions(
+ api_thread, cc_thread, list(update_data.keys()), actions_form, context, request
+ )
# always return read as True (and therefore unread_comment_count=0) as reasonably
# accurate shortcut, rather than adding additional processing.
- api_thread['read'] = True
- api_thread['unread_comment_count'] = 0
+ api_thread["read"] = True
+ api_thread["unread_comment_count"] = 0
return api_thread
@@ -1628,16 +1927,27 @@ def update_comment(request, comment_id, update_data):
"""
cc_comment, context = _get_comment_and_context(request, comment_id)
_check_editable_fields(cc_comment, update_data, context)
- serializer = CommentSerializer(cc_comment, data=update_data, partial=True, context=context)
+ serializer = CommentSerializer(
+ cc_comment, data=update_data, partial=True, context=context
+ )
actions_form = CommentActionsForm(update_data)
if not (serializer.is_valid() and actions_form.is_valid()):
- raise ValidationError(dict(list(serializer.errors.items()) + list(actions_form.errors.items())))
+ raise ValidationError(
+ dict(list(serializer.errors.items()) + list(actions_form.errors.items()))
+ )
# Only save comment object if some of the edited fields are in the comment data, not extra actions
if set(update_data) - set(actions_form.fields):
serializer.save()
comment_edited.send(sender=None, user=request.user, post=cc_comment)
api_comment = serializer.data
- _do_extra_actions(api_comment, cc_comment, list(update_data.keys()), actions_form, context, request)
+ _do_extra_actions(
+ api_comment,
+ cc_comment,
+ list(update_data.keys()),
+ actions_form,
+ context,
+ request,
+ )
_handle_comment_signals(update_data, cc_comment, request.user)
return api_comment
@@ -1671,7 +1981,9 @@ def get_thread(request, thread_id, requested_fields=None, course_id=None):
)
if course_id and course_id != cc_thread.course_id:
raise ThreadNotFoundError("Thread not found.")
- return _serialize_discussion_entities(request, context, [cc_thread], requested_fields, DiscussionEntity.thread)[0]
+ return _serialize_discussion_entities(
+ request, context, [cc_thread], requested_fields, DiscussionEntity.thread
+ )[0]
def get_response_comments(request, comment_id, page, page_size, requested_fields=None):
@@ -1699,7 +2011,10 @@ def get_response_comments(request, comment_id, page, page_size, requested_fields
"""
try:
cc_comment = Comment(id=comment_id).retrieve()
- reverse_order = request.GET.get('reverse_order', False)
+ reverse_order = request.GET.get("reverse_order", False)
+ show_deleted = request.GET.get("show_deleted", False)
+ show_deleted = show_deleted in ["true", "True", True]
+
cc_thread, context = _get_thread_and_context(
request,
cc_comment["thread_id"],
@@ -1707,10 +2022,13 @@ def get_response_comments(request, comment_id, page, page_size, requested_fields
"with_responses": True,
"recursive": True,
"reverse_order": reverse_order,
- }
+ "show_deleted": show_deleted,
+ },
)
if cc_thread["thread_type"] == "question":
- thread_responses = itertools.chain(cc_thread["endorsed_responses"], cc_thread["non_endorsed_responses"])
+ thread_responses = itertools.chain(
+ cc_thread["endorsed_responses"], cc_thread["non_endorsed_responses"]
+ )
else:
thread_responses = cc_thread["children"]
response_comments = []
@@ -1720,16 +2038,35 @@ def get_response_comments(request, comment_id, page, page_size, requested_fields
break
response_skip = page_size * (page - 1)
- paged_response_comments = response_comments[response_skip:(response_skip + page_size)]
+ paged_response_comments = response_comments[
+ response_skip: (response_skip + page_size)
+ ]
if not paged_response_comments and page != 1:
raise PageNotFoundError("Page not found (No results on this page).")
+ if not show_deleted:
+ paged_response_comments = [
+ response
+ for response in paged_response_comments
+ if not response.get("is_deleted", False)
+ ]
+ else:
+ if not context["has_moderation_privilege"]:
+ raise PermissionDenied(
+ "`show_deleted` can only be set by users with moderation roles."
+ )
results = _serialize_discussion_entities(
- request, context, paged_response_comments, requested_fields, DiscussionEntity.comment
+ request,
+ context,
+ paged_response_comments,
+ requested_fields,
+ DiscussionEntity.comment,
)
- comments_count = len(response_comments)
- num_pages = (comments_count + page_size - 1) // page_size if comments_count else 1
+ comments_count = len(paged_response_comments)
+ num_pages = (
+ (comments_count + page_size - 1) // page_size if comments_count else 1
+ )
paginator = DiscussionAPIPagination(request, page, num_pages, comments_count)
return paginator.get_paginated_response(results)
except CommentClientRequestError as err:
@@ -1773,16 +2110,20 @@ def get_user_comments(
context = get_context(course, request)
if flagged and not context["has_moderation_privilege"]:
- raise ValidationError("Only privileged users can filter comments by flagged status")
+ raise ValidationError(
+ "Only privileged users can filter comments by flagged status"
+ )
try:
- response = Comment.retrieve_all({
- 'user_id': author.id,
- 'course_id': str(course_key),
- 'flagged': flagged,
- 'page': page,
- 'per_page': page_size,
- })
+ response = Comment.retrieve_all(
+ {
+ "user_id": author.id,
+ "course_id": str(course_key),
+ "flagged": flagged,
+ "page": page,
+ "per_page": page_size,
+ }
+ )
except CommentClientRequestError as err:
raise CommentNotFoundError("Comment not found") from err
@@ -1822,7 +2163,7 @@ def delete_thread(request, thread_id):
"""
cc_thread, context = _get_thread_and_context(request, thread_id)
if can_delete(cc_thread, context):
- cc_thread.delete()
+ cc_thread.delete(deleted_by=str(request.user.id))
thread_deleted.send(sender=None, user=request.user, post=cc_thread)
track_thread_deleted_event(request, context["course"], cc_thread)
else:
@@ -1847,7 +2188,7 @@ def delete_comment(request, comment_id):
"""
cc_comment, context = _get_comment_and_context(request, comment_id)
if can_delete(cc_comment, context):
- cc_comment.delete()
+ cc_comment.delete(deleted_by=str(request.user.id))
comment_deleted.send(sender=None, user=request.user, post=cc_comment)
track_comment_deleted_event(request, context["course"], cc_comment)
else:
@@ -1879,7 +2220,10 @@ def get_course_discussion_user_stats(
"""
course_key = CourseKey.from_string(course_key_str)
- is_privileged = has_discussion_privileges(user=request.user, course_id=course_key) or request.user.is_staff
+ is_privileged = (
+ has_discussion_privileges(user=request.user, course_id=course_key)
+ or request.user.is_staff
+ )
if is_privileged:
order_by = order_by or UserOrdering.BY_FLAGS
else:
@@ -1888,33 +2232,65 @@ def get_course_discussion_user_stats(
raise ValidationError({"order_by": "Invalid value"})
params = {
- 'sort_key': str(order_by),
- 'page': page,
- 'per_page': page_size,
+ "sort_key": str(order_by),
+ "page": page,
+ "per_page": page_size,
}
comma_separated_usernames = matched_users_count = matched_users_pages = None
if username_search_string:
- comma_separated_usernames, matched_users_count, matched_users_pages = get_usernames_from_search_string(
- course_key, username_search_string, page, page_size
+ comma_separated_usernames, matched_users_count, matched_users_pages = (
+ get_usernames_from_search_string(
+ course_key, username_search_string, page, page_size
+ )
)
search_event_data = {
- 'query': username_search_string,
- 'search_type': 'Learner',
- 'page': params.get('page'),
- 'sort_key': params.get('sort_key'),
- 'total_results': matched_users_count,
+ "query": username_search_string,
+ "search_type": "Learner",
+ "page": params.get("page"),
+ "sort_key": params.get("sort_key"),
+ "total_results": matched_users_count,
}
course = _get_course(course_key, request.user)
track_forum_search_event(request, course, search_event_data)
+
if not comma_separated_usernames:
- return DiscussionAPIPagination(request, 0, 1).get_paginated_response({
- "results": [],
- })
+ return DiscussionAPIPagination(request, 0, 1).get_paginated_response(
+ {
+ "results": [],
+ }
+ )
- params['usernames'] = comma_separated_usernames
+ params["usernames"] = comma_separated_usernames
course_stats_response = get_course_user_stats(course_key, params)
+ # Exclude banned users from the learners list
+    # Get usernames of currently banned users for this course via the forum API
+ get_banned_usernames = getattr(forum_api, 'get_banned_usernames', None)
+ banned_usernames = []
+ # Only filter banned users if feature flag is enabled
+ if ENABLE_DISCUSSION_BAN.is_enabled(course_key) and get_banned_usernames is not None:
+ try:
+ banned_usernames = get_banned_usernames(
+ course_id=course_key,
+ org_key=course_key.org
+ )
+ except Exception: # pylint: disable=broad-except
+ log.exception(
+ "Error retrieving banned usernames for course %s; returning unfiltered discussion stats.",
+ course_key,
+ )
+ banned_usernames = []
+
+ # Filter out banned users from the stats
+ if banned_usernames:
+ course_stats_response["user_stats"] = [
+ stats for stats in course_stats_response["user_stats"]
+ if stats.get('username') not in banned_usernames
+ ]
+ # Update count to reflect filtered results
+ course_stats_response["count"] = len(course_stats_response["user_stats"])
+
if comma_separated_usernames:
updated_course_stats = add_stats_for_users_with_no_discussion_content(
course_stats_response["user_stats"],
@@ -1931,71 +2307,502 @@ def get_course_discussion_user_stats(
paginator = DiscussionAPIPagination(
request,
course_stats_response["page"],
- matched_users_pages if username_search_string else course_stats_response["num_pages"],
- matched_users_count if username_search_string else course_stats_response["count"],
+ (
+ matched_users_pages
+ if username_search_string
+ else course_stats_response["num_pages"]
+ ),
+ (
+ matched_users_count
+ if username_search_string
+ else course_stats_response["count"]
+ ),
+ )
+ return paginator.get_paginated_response(
+ {
+ "results": serializer.data,
+ }
)
- return paginator.get_paginated_response({
- "results": serializer.data,
- })
def get_users_without_stats(
- username_search_string,
- course_key,
- page_number,
- page_size,
- request,
- is_privileged
+ username_search_string, course_key, page_number, page_size, request, is_privileged
):
"""
This return users with no user stats.
This function will be deprecated when this ticket DOS-3414 is resolved
"""
if username_search_string:
- comma_separated_usernames, matched_users_count, matched_users_pages = get_usernames_from_search_string(
- course_key, username_search_string, page_number, page_size
+ comma_separated_usernames, matched_users_count, matched_users_pages = (
+ get_usernames_from_search_string(
+ course_key, username_search_string, page_number, page_size
+ )
)
if not comma_separated_usernames:
- return DiscussionAPIPagination(request, 0, 1).get_paginated_response({
- "results": [],
- })
+ return DiscussionAPIPagination(request, 0, 1).get_paginated_response(
+ {
+ "results": [],
+ }
+ )
else:
- comma_separated_usernames, matched_users_count, matched_users_pages = get_usernames_for_course(
- course_key, page_number, page_size
+ comma_separated_usernames, matched_users_count, matched_users_pages = (
+ get_usernames_for_course(course_key, page_number, page_size)
)
if comma_separated_usernames:
- updated_course_stats = add_stats_for_users_with_null_values([], comma_separated_usernames)
+ updated_course_stats = add_stats_for_users_with_null_values(
+ [], comma_separated_usernames
+ )
- serializer = UserStatsSerializer(updated_course_stats, context={"is_privileged": is_privileged}, many=True)
+ serializer = UserStatsSerializer(
+ updated_course_stats, context={"is_privileged": is_privileged}, many=True
+ )
paginator = DiscussionAPIPagination(
request,
page_number,
matched_users_pages,
matched_users_count,
)
- return paginator.get_paginated_response({
- "results": serializer.data,
- })
+ return paginator.get_paginated_response(
+ {
+ "results": serializer.data,
+ }
+ )
def add_stats_for_users_with_null_values(course_stats, users_in_course):
"""
Update users stats for users with no discussion stats available in course
"""
- users_returned_from_api = [user['username'] for user in course_stats]
- user_list = users_in_course.split(',')
+ users_returned_from_api = [user["username"] for user in course_stats]
+ user_list = users_in_course.split(",")
users_with_no_discussion_content = set(user_list) ^ set(users_returned_from_api)
updated_course_stats = course_stats
for user in users_with_no_discussion_content:
- updated_course_stats.append({
- 'username': user,
- 'threads': None,
- 'replies': None,
- 'responses': None,
- 'active_flags': None,
- 'inactive_flags': None,
- })
- updated_course_stats = sorted(updated_course_stats, key=lambda d: len(d['username']))
+ updated_course_stats.append(
+ {
+ "username": user,
+ "threads": None,
+ "replies": None,
+ "responses": None,
+ "active_flags": None,
+ "inactive_flags": None,
+ }
+ )
+ updated_course_stats = sorted(
+ updated_course_stats, key=lambda d: len(d["username"])
+ )
return updated_course_stats
+
+
+def _get_user_label_function(course_staff_user_ids, moderator_user_ids, ta_user_ids):
+ """
+ Create and return a function that determines user labels based on role.
+
+ Args:
+ course_staff_user_ids: List of user IDs for course staff
+ moderator_user_ids: List of user IDs for moderators
+ ta_user_ids: List of user IDs for TAs
+
+ Returns:
+ A function that takes a user_id and returns the appropriate label or None
+ """
+
+ def get_user_label(user_id):
+ """Get role label for a user ID."""
+ try:
+ user_id_int = int(user_id)
+ if user_id_int in course_staff_user_ids:
+ return "Staff"
+ elif user_id_int in moderator_user_ids:
+ return "Moderator"
+ elif user_id_int in ta_user_ids:
+ return "Community TA"
+ except (ValueError, TypeError):
+ # If user_id has any issues, there's no label to return
+ pass
+ return None
+
+ return get_user_label
+
+
+def _process_deleted_thread(thread_data, get_user_label_fn, usernames_set):
+ """
+ Process a single deleted thread into the standardized content item format.
+
+ Args:
+ thread_data: Raw thread data from forum API
+ get_user_label_fn: Function to get user labels by user ID
+ usernames_set: Set to collect usernames for profile image fetch (modified in-place)
+
+ Returns:
+ dict: Formatted content item for the thread
+ """
+ author_username = thread_data.get("author_username", "") or None
+ author_id = thread_data.get("author_id", "")
+
+ # If author_username is missing or empty, try to get it from author_id
+ if not author_username and author_id:
+ try:
+ author_user = User.objects.get(id=int(author_id))
+ author_username = author_user.username
+ except (User.DoesNotExist, ValueError):
+            # If user not found or invalid ID, leave the author unset (None)
+ author_username = None
+
+ deleted_by_id = thread_data.get("deleted_by")
+ deleted_by_username = None
+
+ # Get deleted_by username
+ if deleted_by_id:
+ try:
+ deleted_user = User.objects.get(id=int(deleted_by_id))
+ deleted_by_username = deleted_user.username
+ usernames_set.add(deleted_by_username)
+ except (User.DoesNotExist, ValueError):
+ # If user not found or invalid ID, skip setting deleted fields
+ pass
+
+ if author_username:
+ usernames_set.add(author_username)
+
+ # Strip HTML tags from preview
+ body_text = thread_data.get("body", "")
+ preview_text = strip_tags(body_text)[:100] if body_text else ""
+
+ thread_id = thread_data.get("_id", thread_data.get("id"))
+
+ # Calculate vote information
+ votes = thread_data.get("votes", {})
+ vote_count = votes.get("up_count", 0) if isinstance(votes, dict) else thread_data.get("vote_count", 0)
+
+ # Get abuse flaggers
+ abuse_flaggers = thread_data.get("abuse_flaggers", [])
+ abuse_flagged_count = len(abuse_flaggers) if abuse_flaggers else None
+
+ return {
+ "id": str(thread_id) + "-thread",
+ "type": "thread",
+ "title": thread_data.get("title", ""),
+ "raw_body": body_text,
+ "rendered_body": body_text, # For deleted content, just use raw body
+ "preview_body": preview_text,
+ "course_id": thread_data.get("course_id", ""),
+ "author": author_username,
+ "author_id": thread_data.get("author_id", ""),
+ "author_label": get_user_label_fn(thread_data.get("author_id")),
+ "topic_id": thread_data.get("commentable_id", ""),
+ "commentable_id": thread_data.get("commentable_id", ""),
+ "group_id": thread_data.get("group_id"),
+ "group_name": None, # Will be populated by API layer if needed
+ "created_at": thread_data.get("created_at"),
+ "updated_at": thread_data.get("updated_at"),
+ "thread_type": thread_data.get("thread_type", "discussion"),
+ "anonymous": thread_data.get("anonymous", False),
+ "anonymous_to_peers": thread_data.get("anonymous_to_peers", False),
+ "pinned": thread_data.get("pinned", False),
+ "closed": thread_data.get("closed", False),
+ "following": False, # Deleted content is not followable
+ "abuse_flagged": len(abuse_flaggers) > 0 if abuse_flaggers else False,
+ "abuse_flagged_count": abuse_flagged_count,
+ "voted": False, # Cannot vote on deleted content
+ "vote_count": vote_count,
+ "comment_count": thread_data.get("comment_count", 0),
+ "unread_comment_count": 0, # Deleted content has no unread count
+ "comment_list_url": None,
+ "endorsed_comment_list_url": None,
+ "non_endorsed_comment_list_url": None,
+ "read": True, # Treat deleted content as read
+ "has_endorsed": thread_data.get("endorsed", False),
+ "editable_fields": [], # Deleted content is not editable
+ "can_delete": False, # Already deleted
+ "is_deleted": True,
+ "deleted_at": thread_data.get("deleted_at"),
+ "deleted_by": deleted_by_username,
+ "deleted_by_label": get_user_label_fn(deleted_by_id) if deleted_by_id else None,
+ "close_reason_code": thread_data.get("close_reason_code"),
+ "close_reason": None,
+ "closed_by": thread_data.get("closed_by"),
+ "closed_by_label": None,
+ }
+
+
+def _process_deleted_comment(comment_data, get_user_label_fn, usernames_set):
+    """
+    Process a single deleted comment into the standardized content item format.
+
+    Args:
+        comment_data: Raw comment data from forum API
+        get_user_label_fn: Function to get user labels by user ID
+        usernames_set: Set to collect usernames for profile image fetch (modified in-place)
+
+    Returns:
+        dict: Formatted content item for the comment. The "id" value carries a
+        "-comment" suffix (distinguishing comments from threads in combined
+        listings), and moderation fields (editable_fields, can_delete, voted)
+        are locked down because the content is already deleted.
+    """
+    author_username = comment_data.get("author_username", "") or None
+    author_id = comment_data.get("author_id", "")
+
+    # If author_username is missing or empty, try to resolve it from author_id
+    # against the LMS user table.
+    if not author_username and author_id:
+        try:
+            author_user = User.objects.get(id=int(author_id))
+            author_username = author_user.username
+        except (User.DoesNotExist, ValueError):
+            # If user not found or invalid ID, leave the author unnamed
+            author_username = None
+
+    deleted_by_id = comment_data.get("deleted_by")
+    deleted_by_username = None
+
+    # Resolve the deleting moderator's username and queue it for the
+    # profile-image fetch.
+    if deleted_by_id:
+        try:
+            deleted_user = User.objects.get(id=int(deleted_by_id))
+            deleted_by_username = deleted_user.username
+            usernames_set.add(deleted_by_username)
+        except (User.DoesNotExist, ValueError):
+            # If user not found or invalid ID, skip setting deleted fields
+            pass
+
+    if author_username:
+        usernames_set.add(author_username)
+
+    # Determine if this is a response (depth=0) or comment (depth>0)
+    depth = comment_data.get("depth", 0)
+    comment_type = "response" if depth == 0 else "comment"
+
+    # Get parent thread title for context. A failed lookup (e.g. the thread is
+    # also deleted or the forum service errors) leaves the title blank rather
+    # than failing the whole listing.
+    thread_id = comment_data.get("comment_thread_id", "")
+    thread_title = ""
+    if thread_id:
+        try:
+            parent_thread = Thread(id=thread_id).retrieve()
+            thread_title = parent_thread.get("title", "")
+        except Exception:  # pylint: disable=broad-exception-caught
+            pass
+
+    # Strip HTML tags from preview; capped at 100 characters.
+    body_text = comment_data.get("body", "")
+    preview_text = strip_tags(body_text)[:100] if body_text else ""
+
+    # Forum data may expose the identifier as Mongo-style "_id" or plain "id".
+    comment_id = comment_data.get("_id", comment_data.get("id"))
+
+    # Calculate vote information: "votes" may be a dict carrying "up_count",
+    # otherwise fall back to a flat "vote_count" field.
+    votes = comment_data.get("votes", {})
+    vote_count = votes.get("up_count", 0) if isinstance(votes, dict) else comment_data.get("vote_count", 0)
+
+    # Get abuse flaggers. NOTE(review): the count is None (not 0) when there
+    # are no flaggers — confirm downstream consumers expect None here.
+    abuse_flaggers = comment_data.get("abuse_flaggers", [])
+    abuse_flagged_count = len(abuse_flaggers) if abuse_flaggers else None
+
+    return {
+        "id": str(comment_id) + "-comment",
+        "type": comment_type,
+        "raw_body": body_text,
+        "rendered_body": body_text,  # For deleted content, just use raw body
+        "preview_body": preview_text,
+        "title": thread_title,  # Use parent thread title for comments/responses
+        "course_id": comment_data.get("course_id", ""),
+        "author": author_username,
+        "author_id": comment_data.get("author_id", ""),
+        "author_label": get_user_label_fn(comment_data.get("author_id")),
+        "thread_id": str(thread_id),
+        "comment_thread_id": str(thread_id),
+        "thread_title": thread_title,
+        "parent_id": (
+            str(comment_data.get("parent_id", ""))
+            if comment_data.get("parent_id")
+            else None
+        ),
+        "created_at": comment_data.get("created_at"),
+        "updated_at": comment_data.get("updated_at"),
+        "depth": depth,
+        "anonymous": comment_data.get("anonymous", False),
+        "anonymous_to_peers": comment_data.get("anonymous_to_peers", False),
+        "endorsed": comment_data.get("endorsed", False),
+        "endorsed_by": comment_data.get("endorsed_by"),
+        "endorsed_by_label": None,
+        "endorsed_at": comment_data.get("endorsed_at"),
+        "abuse_flagged": len(abuse_flaggers) > 0 if abuse_flaggers else False,
+        "abuse_flagged_count": abuse_flagged_count,
+        "voted": False,  # Cannot vote on deleted content
+        "vote_count": vote_count,
+        "editable_fields": [],  # Deleted content is not editable
+        "can_delete": False,  # Already deleted
+        "child_count": comment_data.get("child_count", 0),
+        "is_deleted": True,
+        "deleted_at": comment_data.get("deleted_at"),
+        "deleted_by": deleted_by_username,
+        "deleted_by_label": get_user_label_fn(deleted_by_id) if deleted_by_id else None,
+    }
+
+
+def _add_user_profiles_to_content(deleted_content, usernames_set, request):
+ """
+ Fetch user profile images and add them to each content item.
+
+ Args:
+ deleted_content: List of content items (modified in-place)
+ usernames_set: Set of usernames to fetch profile images for
+ request: Django request object for getting profile images
+ """
+ # Add profile images for all users
+ username_profile_dict = _get_user_profile_dict(
+ request, usernames=",".join(usernames_set)
+ )
+
+ # Add users dict with profile images to each item
+ for item in deleted_content:
+ users_dict = {}
+
+ # Add author profile
+ author_username = item.get("author")
+ if author_username and author_username in username_profile_dict:
+ users_dict[author_username] = _user_profile(
+ username_profile_dict[author_username]
+ )
+
+ # Add deleted_by profile
+ deleted_by_username = item.get("deleted_by")
+ if deleted_by_username and deleted_by_username in username_profile_dict:
+ users_dict[deleted_by_username] = _user_profile(
+ username_profile_dict[deleted_by_username]
+ )
+
+ item["users"] = users_dict
+
+
+def get_deleted_content_for_course(
+ request, course_id, content_type=None, page=1, per_page=20, author_id=None
+):
+ """
+ Retrieve all deleted content (threads, comments) for a course.
+
+ Args:
+ request: The django request object for getting user profile images
+ course_id (str): Course identifier
+ content_type (str, optional): Filter by 'thread' or 'comment'. If None, returns all types.
+ page (int): Page number for pagination (1-based)
+ per_page (int): Number of items per page
+ author_id (str, optional): Filter by author ID
+
+ Returns:
+ dict: Paginated results with deleted content including author labels and profile images
+ """
+
+ import math
+
+ from lms.djangoapps.discussion.rest_api.utils import (
+ get_course_staff_users_list,
+ get_course_ta_users_list,
+ get_moderator_users_list,
+ )
+
+ try:
+ # Get course and user role information for labels
+ course_key = CourseKey.from_string(course_id)
+ course = _get_course(course_key, request.user)
+
+ course_staff_user_ids = get_course_staff_users_list(course.id)
+ moderator_user_ids = get_moderator_users_list(course.id)
+ ta_user_ids = get_course_ta_users_list(course.id)
+
+ # Get user label function
+ get_user_label = _get_user_label_function(
+ course_staff_user_ids, moderator_user_ids, ta_user_ids
+ )
+
+ # Build query parameters for forum API
+ query_params = {
+ "course_id": course_id,
+ "is_deleted": True, # Only get deleted content
+ "page": page,
+ "per_page": per_page,
+ }
+
+ if author_id:
+ query_params["author_id"] = author_id
+
+ deleted_content = []
+ total_count = 0
+ usernames_set = set() # Track all usernames for profile image fetch
+
+ # Get deleted threads
+ if content_type is None or content_type == "thread":
+ try:
+ deleted_threads = forum_api.get_deleted_threads_for_course(
+ course_id=course_id,
+ page=page if content_type == "thread" else 1,
+ per_page=per_page if content_type == "thread" else 1000,
+ author_id=author_id,
+ )
+ for thread_data in deleted_threads.get("threads", []):
+ content_item = _process_deleted_thread(
+ thread_data, get_user_label, usernames_set
+ )
+ deleted_content.append(content_item)
+
+ if content_type == "thread":
+ total_count = deleted_threads.get(
+ "total_count", len(deleted_content)
+ )
+ except Exception as e: # pylint: disable=broad-exception-caught
+ log.warning(
+ "Failed to get deleted threads for course %s: %s", course_id, e
+ )
+
+ # Get deleted comments
+ if content_type is None or content_type == "comment":
+ try:
+ deleted_comments = forum_api.get_deleted_comments_for_course(
+ course_id=course_id,
+ page=page if content_type == "comment" else 1,
+ per_page=per_page if content_type == "comment" else 1000,
+ author_id=author_id,
+ )
+ for comment_data in deleted_comments.get("comments", []):
+ content_item = _process_deleted_comment(
+ comment_data, get_user_label, usernames_set
+ )
+ deleted_content.append(content_item)
+
+ if content_type == "comment":
+ total_count = deleted_comments.get(
+ "total_count", len(deleted_content)
+ )
+ except Exception as e: # pylint: disable=broad-exception-caught
+ log.warning(
+ "Failed to get deleted comments for course %s: %s", course_id, e
+ )
+
+ # If getting all content types, handle pagination differently
+ if content_type is None:
+ total_count = len(deleted_content)
+ # Sort by deletion date (most recent first)
+ deleted_content.sort(key=lambda x: x.get("deleted_at", ""), reverse=True)
+
+ # Apply pagination to combined results
+ start_idx = (page - 1) * per_page
+ end_idx = start_idx + per_page
+ deleted_content = deleted_content[start_idx:end_idx]
+
+ # Add profile images for all users
+ _add_user_profiles_to_content(deleted_content, usernames_set, request)
+
+ # Calculate pagination info
+ num_pages = math.ceil(total_count / per_page) if total_count > 0 else 1
+
+ return {
+ "results": deleted_content,
+ "pagination": {
+ "next": None, # Can be computed if needed
+ "previous": None, # Can be computed if needed
+ "count": total_count,
+ "num_pages": num_pages,
+ },
+ }
+
+ except Exception as e:
+ log.exception("Error getting deleted content for course %s: %s", course_id, e)
+ raise
diff --git a/lms/djangoapps/discussion/rest_api/emails.py b/lms/djangoapps/discussion/rest_api/emails.py
new file mode 100644
index 000000000000..e4ebcc21a567
--- /dev/null
+++ b/lms/djangoapps/discussion/rest_api/emails.py
@@ -0,0 +1,170 @@
+"""
+Email notifications for discussion moderation actions.
+"""
+import logging
+
+from django.conf import settings
+from django.contrib.auth import get_user_model
+
+log = logging.getLogger(__name__)
+User = get_user_model()
+
+# Try to import ACE at module level for easier testing
+try:
+ from edx_ace import ace
+ from edx_ace.recipient import Recipient
+ from edx_ace.message import Message
+ ACE_AVAILABLE = True
+except ImportError:
+ ace = None
+ Recipient = None
+ Message = None
+ ACE_AVAILABLE = False
+
+
+def send_ban_escalation_email(
+    banned_user_id,
+    moderator_id,
+    course_id,
+    scope,
+    reason,
+    threads_deleted,
+    comments_deleted
+):
+    """
+    Send email to partner-support when user is banned.
+
+    Uses ACE (Automated Communications Engine) for templated emails if available,
+    otherwise falls back to Django's email system, rendering the
+    'discussion/ban_escalation_email.txt' template when present or a built-in
+    plain-text body when it is not.
+
+    Controlled by settings:
+        DISCUSSION_MODERATION_BAN_EMAIL_ENABLED (default True): no-op when False.
+        DISCUSSION_MODERATION_ESCALATION_EMAIL: recipient address
+            (default 'partner-support@edx.org').
+
+    Args:
+        banned_user_id: ID of the banned user
+        moderator_id: ID of the moderator who applied the ban
+        course_id: Course ID where ban was applied
+        scope: 'course' or 'organization'
+        reason: Reason for the ban
+        threads_deleted: Number of threads deleted
+        comments_deleted: Number of comments deleted
+
+    Raises:
+        User.DoesNotExist: if either user id cannot be resolved.
+        Exception: any delivery failure is logged and re-raised to the caller.
+    """
+    # Check if email notifications are enabled; disabled installs still get an
+    # info log so the ban itself remains auditable.
+    if not getattr(settings, 'DISCUSSION_MODERATION_BAN_EMAIL_ENABLED', True):
+        log.info(
+            "Ban email notifications disabled by settings. "
+            "User %s banned in course %s (scope: %s)",
+            banned_user_id, course_id, scope
+        )
+        return
+
+    try:
+        banned_user = User.objects.get(id=banned_user_id)
+        moderator = User.objects.get(id=moderator_id)
+
+        # Get escalation email from settings
+        escalation_email = getattr(
+            settings,
+            'DISCUSSION_MODERATION_ESCALATION_EMAIL',
+            'partner-support@edx.org'
+        )
+
+        # Try using ACE first (preferred method for edX)
+        if ACE_AVAILABLE and ace is not None:
+            # NOTE(review): Recipient is built with lms_user_id=None because
+            # the escalation address is not an LMS user — confirm the
+            # installed edx-ace version accepts a None lms_user_id.
+            message = Message(
+                app_label='discussion',
+                name='ban_escalation',
+                recipient=Recipient(lms_user_id=None, email_address=escalation_email),
+                context={
+                    'banned_username': banned_user.username,
+                    'banned_email': banned_user.email,
+                    'banned_user_id': banned_user_id,
+                    'moderator_username': moderator.username,
+                    'moderator_email': moderator.email,
+                    'moderator_id': moderator_id,
+                    'course_id': str(course_id),
+                    'scope': scope,
+                    'reason': reason or 'No reason provided',
+                    'threads_deleted': threads_deleted,
+                    'comments_deleted': comments_deleted,
+                    'total_deleted': threads_deleted + comments_deleted,
+                }
+            )
+
+            ace.send(message)
+            log.info(
+                "Ban escalation email sent via ACE to %s for user %s in course %s",
+                escalation_email, banned_user.username, course_id
+            )
+
+        else:
+            # Fallback to Django's email system if ACE is not available
+            from django.core.mail import send_mail
+            from django.template.loader import render_to_string
+            from django.template import TemplateDoesNotExist
+
+            # Same context shape as the ACE branch, so both paths can share
+            # the 'discussion/ban_escalation_email.txt' template.
+            context = {
+                'banned_username': banned_user.username,
+                'banned_email': banned_user.email,
+                'banned_user_id': banned_user_id,
+                'moderator_username': moderator.username,
+                'moderator_email': moderator.email,
+                'moderator_id': moderator_id,
+                'course_id': str(course_id),
+                'scope': scope,
+                'reason': reason or 'No reason provided',
+                'threads_deleted': threads_deleted,
+                'comments_deleted': comments_deleted,
+                'total_deleted': threads_deleted + comments_deleted,
+            }
+
+            # Try to render template, fall back to plain text if template doesn't exist
+            try:
+                email_body = render_to_string(
+                    'discussion/ban_escalation_email.txt',
+                    context
+                )
+            except TemplateDoesNotExist:
+                # Plain text fallback
+                banned_user_info = "{} ({})".format(banned_user.username, banned_user.email)
+                moderator_info = "{} ({})".format(moderator.username, moderator.email)
+                email_body = """
+A user has been banned from discussions:
+
+Banned User: {}
+Moderator: {}
+Course: {}
+Scope: {}
+Reason: {}
+Content Deleted: {} threads, {} comments
+
+Please review this moderation action and follow up as needed.
+""".format(
+    banned_user_info,
+    moderator_info,
+    course_id,
+    scope,
+    reason or 'No reason provided',
+    threads_deleted,
+    comments_deleted
+)
+
+            subject = f'Discussion Ban Alert: {banned_user.username} in {course_id}'
+            from_email = getattr(settings, 'DEFAULT_FROM_EMAIL', 'no-reply@example.com')
+
+            # fail_silently=False so delivery problems surface to the caller
+            send_mail(
+                subject=subject,
+                message=email_body,
+                from_email=from_email,
+                recipient_list=[escalation_email],
+                fail_silently=False,
+            )
+
+            log.info(
+                "Ban escalation email sent via Django mail to %s for user %s in course %s",
+                escalation_email, banned_user.username, course_id
+            )
+
+    except User.DoesNotExist as e:
+        log.error("Failed to send ban escalation email: User not found - %s", str(e))
+        raise
+    except Exception as exc:
+        log.error("Failed to send ban escalation email: %s", str(exc), exc_info=True)
+        raise
diff --git a/lms/djangoapps/discussion/rest_api/forms.py b/lms/djangoapps/discussion/rest_api/forms.py
index 8cc7127645b2..f37543723792 100644
--- a/lms/djangoapps/discussion/rest_api/forms.py
+++ b/lms/djangoapps/discussion/rest_api/forms.py
@@ -1,6 +1,7 @@
"""
Discussion API forms
"""
+
import urllib.parse
from django.core.exceptions import ValidationError
@@ -22,13 +23,15 @@
class UserOrdering(TextChoices):
- BY_ACTIVITY = 'activity'
- BY_FLAGS = 'flagged'
- BY_RECENT_ACTIVITY = 'recency'
+ BY_ACTIVITY = "activity"
+ BY_FLAGS = "flagged"
+ BY_RECENT_ACTIVITY = "recency"
+ BY_DELETED = "deleted"
class _PaginationForm(Form):
"""A form that includes pagination fields"""
+
page = IntegerField(required=False, min_value=1)
page_size = IntegerField(required=False, min_value=1)
@@ -45,6 +48,7 @@ class ThreadListGetForm(_PaginationForm):
"""
A form to validate query parameters in the thread list retrieval endpoint
"""
+
EXCLUSIVE_PARAMS = ["topic_id", "text_search", "following"]
course_id = CharField()
@@ -58,17 +62,22 @@ class ThreadListGetForm(_PaginationForm):
)
count_flagged = ExtendedNullBooleanField(required=False)
flagged = ExtendedNullBooleanField(required=False)
+ show_deleted = ExtendedNullBooleanField(required=False)
view = ChoiceField(
- choices=[(choice, choice) for choice in ["unread", "unanswered", "unresponded"]],
+ choices=[
+ (choice, choice) for choice in ["unread", "unanswered", "unresponded"]
+ ],
required=False,
)
order_by = ChoiceField(
- choices=[(choice, choice) for choice in ["last_activity_at", "comment_count", "vote_count"]],
- required=False
+ choices=[
+ (choice, choice)
+ for choice in ["last_activity_at", "comment_count", "vote_count"]
+ ],
+ required=False,
)
order_direction = ChoiceField(
- choices=[(choice, choice) for choice in ["desc"]],
- required=False
+ choices=[(choice, choice) for choice in ["desc"]], required=False
)
requested_fields = MultiValueField(required=False)
@@ -85,14 +94,16 @@ def clean_course_id(self):
value = self.cleaned_data["course_id"]
try:
return CourseLocator.from_string(value)
- except InvalidKeyError:
- raise ValidationError(f"'{value}' is not a valid course id") # lint-amnesty, pylint: disable=raise-missing-from
+ except InvalidKeyError as e:
+ raise ValidationError(f"'{value}' is not a valid course id") from e
def clean_following(self):
"""Validate following"""
value = self.cleaned_data["following"]
if value is False: # lint-amnesty, pylint: disable=no-else-raise
- raise ValidationError("The value of the 'following' parameter must be true.")
+ raise ValidationError(
+ "The value of the 'following' parameter must be true."
+ )
else:
return value
@@ -115,6 +126,7 @@ class ThreadActionsForm(Form):
A form to handle fields in thread creation/update that require separate
interactions with the comments service.
"""
+
following = BooleanField(required=False)
voted = BooleanField(required=False)
abuse_flagged = BooleanField(required=False)
@@ -126,17 +138,20 @@ class CommentListGetForm(_PaginationForm):
"""
A form to validate query parameters in the comment list retrieval endpoint
"""
+
thread_id = CharField()
flagged = BooleanField(required=False)
endorsed = ExtendedNullBooleanField(required=False)
requested_fields = MultiValueField(required=False)
merge_question_type_responses = BooleanField(required=False)
+ show_deleted = ExtendedNullBooleanField(required=False)
class UserCommentListGetForm(_PaginationForm):
"""
A form to validate query parameters in the comment list retrieval endpoint
"""
+
course_id = CharField()
flagged = BooleanField(required=False)
requested_fields = MultiValueField(required=False)
@@ -146,8 +161,8 @@ def clean_course_id(self):
value = self.cleaned_data["course_id"]
try:
return CourseLocator.from_string(value)
- except InvalidKeyError:
- raise ValidationError(f"'{value}' is not a valid course id") # lint-amnesty, pylint: disable=raise-missing-from
+ except InvalidKeyError as e:
+ raise ValidationError(f"'{value}' is not a valid course id") from e
class CommentActionsForm(Form):
@@ -155,6 +170,7 @@ class CommentActionsForm(Form):
A form to handle fields in comment creation/update that require separate
interactions with the comments service.
"""
+
voted = BooleanField(required=False)
abuse_flagged = BooleanField(required=False)
@@ -163,6 +179,7 @@ class CommentGetForm(_PaginationForm):
"""
A form to validate query parameters in the comment retrieval endpoint
"""
+
requested_fields = MultiValueField(required=False)
@@ -170,28 +187,34 @@ class CourseDiscussionSettingsForm(Form):
"""
A form to validate the fields in the course discussion settings requests.
"""
+
course_id = CharField()
def __init__(self, *args, **kwargs):
- self.request_user = kwargs.pop('request_user')
+ self.request_user = kwargs.pop("request_user")
super().__init__(*args, **kwargs)
def clean_course_id(self):
"""Validate the 'course_id' value"""
- course_id = self.cleaned_data['course_id']
+ course_id = self.cleaned_data["course_id"]
try:
course_key = CourseKey.from_string(course_id)
- self.cleaned_data['course'] = get_course_with_access(self.request_user, 'load', course_key)
- self.cleaned_data['course_key'] = course_key
+ self.cleaned_data["course"] = get_course_with_access(
+ self.request_user, "load", course_key
+ )
+ self.cleaned_data["course_key"] = course_key
return course_id
- except InvalidKeyError:
- raise ValidationError(f"'{str(course_id)}' is not a valid course key") # lint-amnesty, pylint: disable=raise-missing-from
+ except InvalidKeyError as e:
+ raise ValidationError(
+ f"'{str(course_id)}' is not a valid course key"
+ ) from e
class CourseDiscussionRolesForm(CourseDiscussionSettingsForm):
"""
A form to validate the fields in the course discussion roles requests.
"""
+
ROLE_CHOICES = (
(FORUM_ROLE_MODERATOR, FORUM_ROLE_MODERATOR),
(FORUM_ROLE_COMMUNITY_TA, FORUM_ROLE_MODERATOR),
@@ -199,20 +222,20 @@ class CourseDiscussionRolesForm(CourseDiscussionSettingsForm):
)
rolename = ChoiceField(
choices=ROLE_CHOICES,
- error_messages={"invalid_choice": "Role '%(value)s' does not exist"}
+ error_messages={"invalid_choice": "Role '%(value)s' does not exist"},
)
def clean_rolename(self):
"""Validate the 'rolename' value."""
- rolename = urllib.parse.unquote(self.cleaned_data.get('rolename'))
- course_id = self.cleaned_data.get('course_key')
+ rolename = urllib.parse.unquote(self.cleaned_data.get("rolename"))
+ course_id = self.cleaned_data.get("course_key")
if course_id and rolename:
try:
role = Role.objects.get(name=rolename, course_id=course_id)
except Role.DoesNotExist as err:
raise ValidationError(f"Role '{rolename}' does not exist") from err
- self.cleaned_data['role'] = role
+ self.cleaned_data["role"] = role
return rolename
@@ -220,15 +243,17 @@ class TopicListGetForm(Form):
"""
Form for the topics API get query parameters.
"""
+
topic_id = CharField(required=False)
order_by = ChoiceField(choices=TopicOrdering.choices, required=False)
def clean_topic_id(self):
topic_ids = self.cleaned_data.get("topic_id", None)
- return set(topic_ids.strip(',').split(',')) if topic_ids else None
+ return set(topic_ids.strip(",").split(",")) if topic_ids else None
class CourseActivityStatsForm(_PaginationForm):
"""Form for validating course activity stats API query parameters"""
+
order_by = ChoiceField(choices=UserOrdering.choices, required=False)
username = CharField(required=False)
diff --git a/lms/djangoapps/discussion/rest_api/permissions.py b/lms/djangoapps/discussion/rest_api/permissions.py
index cfcea5b32834..a3aa6398915b 100644
--- a/lms/djangoapps/discussion/rest_api/permissions.py
+++ b/lms/djangoapps/discussion/rest_api/permissions.py
@@ -3,10 +3,11 @@
"""
from typing import Dict, Set, Union
+from opaque_keys import InvalidKeyError
from opaque_keys.edx.keys import CourseKey
from rest_framework import permissions
-from common.djangoapps.student.models import CourseAccessRole, CourseEnrollment
+from common.djangoapps.student.models import CourseEnrollment
from common.djangoapps.student.roles import (
CourseInstructorRole,
CourseStaffRole,
@@ -189,42 +190,137 @@ def has_permission(self, request, view):
def can_take_action_on_spam(user, course_id):
"""
- Returns if the user has access to take action against forum spam posts
+ Returns if the user has access to take action against forum spam posts.
+
+ Grants access to:
+ - Global Staff (user.is_staff or GlobalStaff role)
+ - Course Staff for the specific course
+ - Course Instructors for the specific course
+ - Forum Moderators for the specific course
+ - Forum Administrators for the specific course
+
Parameters:
user: User object
course_id: CourseKey or string of course_id
+
+ Returns:
+ bool: True if user can take action on spam, False otherwise
"""
- if GlobalStaff().has_user(user):
+ # Global staff have universal access
+ if GlobalStaff().has_user(user) or user.is_staff:
return True
if isinstance(course_id, str):
course_id = CourseKey.from_string(course_id)
- org_id = course_id.org
- course_ids = CourseEnrollment.objects.filter(user=user).values_list('course_id', flat=True)
- course_ids = [c_id for c_id in course_ids if c_id.org == org_id]
+
+ # Check if user is Course Staff or Instructor for this specific course
+ if CourseStaffRole(course_id).has_user(user):
+ return True
+
+ if CourseInstructorRole(course_id).has_user(user):
+ return True
+
+ # Check forum moderator/administrator roles for this specific course
user_roles = set(
Role.objects.filter(
users=user,
- course_id__in=course_ids,
- ).values_list('name', flat=True).distinct()
+ course_id=course_id,
+ ).values_list('name', flat=True)
)
- if bool(user_roles & {FORUM_ROLE_ADMINISTRATOR, FORUM_ROLE_MODERATOR}):
- return True
- if CourseAccessRole.objects.filter(user=user, course_id__in=course_ids, role__in=["instructor", "staff"]).exists():
+ if user_roles & {FORUM_ROLE_ADMINISTRATOR, FORUM_ROLE_MODERATOR}:
return True
+
return False
class IsAllowedToBulkDelete(permissions.BasePermission):
"""
- Permission that checks if the user is staff or an admin.
+ Permission that checks if the user is allowed to perform bulk delete and ban operations.
+
+ Grants access to:
+ - Global Staff (superusers)
+ - Course Staff
+ - Course Instructors
+ - Forum Moderators
+ - Forum Administrators
+
+ Denies access to:
+ - Unauthenticated users
+ - Regular students
+ - Community TAs (they can moderate individual posts but not bulk delete)
"""
def has_permission(self, request, view):
- """Returns true if the user can bulk delete posts"""
+ """
+ Returns True if the user can bulk delete posts and ban users.
+
+ For ViewSet actions, course_id may come from:
+ 1. URL kwargs (view.kwargs.get('course_id')) - for URL path parameters
+ 2. Request body (request.data.get('course_id')) - for POST request bodies
+ """
if not request.user.is_authenticated:
return False
- course_id = view.kwargs.get("course_id")
+ # Try to get course_id from URL kwargs or request data
+ course_id = (
+ view.kwargs.get("course_id") or
+ (request.data.get("course_id") if hasattr(request, 'data') else None)
+ )
+
+ # If no course_id provided, we can't check permissions yet
+ # Let the view handle validation of required course_id
+ if not course_id:
+ # For safety, only allow global staff to proceed without course_id
+ return GlobalStaff().has_user(request.user) or request.user.is_staff
+
return can_take_action_on_spam(request.user, course_id)
+
+
+class IsAllowedToRestore(permissions.BasePermission):
+ """
+ Permission that checks if the user has privileges to restore individual deleted content.
+
+ This permission is intentionally more permissive than IsAllowedToBulkDelete because:
+ - Restoring individual content is a less risky operation than bulk deletion
+ - Users who can see deleted content should be able to restore it
+ - Course-level moderation staff need this capability for day-to-day moderation
+
+ Allowed users (course-level permissions):
+ - Global staff (platform-wide)
+ - Course instructors
+ - Course staff
+ - Discussion moderators (course-specific)
+ - Discussion community TAs (course-specific)
+ - Discussion administrators (course-specific)
+ """
+
+ def has_permission(self, request, view):
+ """Returns true if the user can restore deleted posts"""
+ if not request.user.is_authenticated:
+ return False
+
+ # For restore operations, course_id is in request.data, not URL kwargs
+ course_id = request.data.get("course_id")
+ if not course_id:
+ return False
+
+ # Global staff always has permission
+ if GlobalStaff().has_user(request.user):
+ return True
+
+ try:
+ course_key = CourseKey.from_string(course_id)
+ except InvalidKeyError:
+ return False
+
+ # Check if user is course staff or instructor
+ if CourseStaffRole(course_key).has_user(request.user) or \
+ CourseInstructorRole(course_key).has_user(request.user):
+ return True
+
+ # Check if user has discussion privileges (moderator, community TA, administrator)
+ if has_discussion_privileges(request.user, course_key):
+ return True
+
+ return False
diff --git a/lms/djangoapps/discussion/rest_api/serializers.py b/lms/djangoapps/discussion/rest_api/serializers.py
index 9c2668d0b226..1f7f2264cd19 100644
--- a/lms/djangoapps/discussion/rest_api/serializers.py
+++ b/lms/djangoapps/discussion/rest_api/serializers.py
@@ -1,13 +1,13 @@
"""
Discussion API serializers
"""
+
import html
import re
-
-from bs4 import BeautifulSoup
from typing import Dict
from urllib.parse import urlencode, urlunparse
+from bs4 import BeautifulSoup
from django.conf import settings
from django.contrib.auth import get_user_model
from django.core.exceptions import ObjectDoesNotExist, ValidationError
@@ -18,8 +18,12 @@
from common.djangoapps.student.models import get_user_by_username_or_email
from common.djangoapps.student.roles import GlobalStaff
-from lms.djangoapps.discussion.django_comment_client.base.views import track_thread_lock_unlock_event, \
- track_thread_edited_event, track_comment_edited_event, track_forum_response_mark_event
+from lms.djangoapps.discussion.django_comment_client.base.views import (
+ track_comment_edited_event,
+ track_forum_response_mark_event,
+ track_thread_edited_event,
+ track_thread_lock_unlock_event,
+)
from lms.djangoapps.discussion.django_comment_client.utils import (
course_discussion_division_enabled,
get_group_id_for_user,
@@ -35,16 +39,23 @@
from lms.djangoapps.discussion.rest_api.render import render_body
from lms.djangoapps.discussion.rest_api.utils import (
get_course_staff_users_list,
- get_moderator_users_list,
get_course_ta_users_list,
+ get_moderator_users_list,
+ get_user_learner_status,
)
from openedx.core.djangoapps.discussions.models import DiscussionTopicLink
from openedx.core.djangoapps.discussions.utils import get_group_names_by_id
from openedx.core.djangoapps.django_comment_common.comment_client.comment import Comment
from openedx.core.djangoapps.django_comment_common.comment_client.thread import Thread
-from openedx.core.djangoapps.django_comment_common.comment_client.user import User as CommentClientUser
-from openedx.core.djangoapps.django_comment_common.comment_client.utils import CommentClientRequestError
-from openedx.core.djangoapps.django_comment_common.models import CourseDiscussionSettings
+from openedx.core.djangoapps.django_comment_common.comment_client.user import (
+ User as CommentClientUser,
+)
+from openedx.core.djangoapps.django_comment_common.comment_client.utils import (
+ CommentClientRequestError,
+)
+from openedx.core.djangoapps.django_comment_common.models import (
+ CourseDiscussionSettings,
+)
from openedx.core.djangoapps.user_api.accounts.api import get_profile_images
from openedx.core.lib.api.serializers import CourseKeyField
@@ -58,6 +69,7 @@ class TopicOrdering(TextChoices):
"""
Enum for the available options for ordering topics.
"""
+
COURSE_STRUCTURE = "course_structure", "Course Structure"
ACTIVITY = "activity", "Activity"
NAME = "name", "Name"
@@ -72,16 +84,22 @@ def get_context(course, request, thread=None):
moderator_user_ids = get_moderator_users_list(course.id)
ta_user_ids = get_course_ta_users_list(course.id)
requester = request.user
- cc_requester = CommentClientUser.from_django_user(requester).retrieve(course_id=course.id)
+ cc_requester = CommentClientUser.from_django_user(requester).retrieve(
+ course_id=course.id
+ )
cc_requester["course_id"] = course.id
course_discussion_settings = CourseDiscussionSettings.get(course.id)
is_global_staff = GlobalStaff().has_user(requester)
- has_moderation_privilege = requester.id in moderator_user_ids or requester.id in ta_user_ids or is_global_staff
+ all_privileged_ids = set(moderator_user_ids) | set(ta_user_ids) | set(course_staff_user_ids)
+ has_moderation_privilege = requester.id in all_privileged_ids or is_global_staff
return {
"course": course,
+ "course_id": course.id,
"request": request,
"thread": thread,
- "discussion_division_enabled": course_discussion_division_enabled(course_discussion_settings),
+ "discussion_division_enabled": course_discussion_division_enabled(
+ course_discussion_settings
+ ),
"group_ids_to_names": get_group_names_by_id(course_discussion_settings),
"moderator_user_ids": moderator_user_ids,
"course_staff_user_ids": course_staff_user_ids,
@@ -136,8 +154,8 @@ def _validate_privileged_access(context: Dict) -> bool:
Returns:
bool: Course exists and the user has privileged access.
"""
- course = context.get('course', None)
- is_requester_privileged = context.get('has_moderation_privilege')
+ course = context.get("course", None)
+ is_requester_privileged = context.get("has_moderation_privilege")
return course and is_requester_privileged
@@ -157,7 +175,7 @@ def filter_spam_urls_from_html(html_string):
patterns.append(re.compile(rf"(https?://)?{domain_pattern}", re.IGNORECASE))
for a_tag in soup.find_all("a", href=True):
- href = a_tag.get('href')
+ href = a_tag.get("href")
if href:
if any(p.search(href) for p in patterns):
a_tag.replace_with(a_tag.get_text(strip=True))
@@ -166,7 +184,7 @@ def filter_spam_urls_from_html(html_string):
for text_node in soup.find_all(string=True):
new_text = text_node
for p in patterns:
- new_text = p.sub('', new_text)
+ new_text = p.sub("", new_text)
if new_text != text_node:
text_node.replace_with(new_text.strip())
is_spam = True
@@ -182,6 +200,7 @@ class _ContentSerializer(serializers.Serializer):
id = serializers.CharField(read_only=True) # pylint: disable=invalid-name
author = serializers.SerializerMethodField()
author_label = serializers.SerializerMethodField()
+ learner_status = serializers.SerializerMethodField()
created_at = serializers.CharField(read_only=True)
updated_at = serializers.CharField(read_only=True)
raw_body = serializers.CharField(source="body", validators=[validate_not_blank])
@@ -194,8 +213,16 @@ class _ContentSerializer(serializers.Serializer):
anonymous = serializers.BooleanField(default=False)
anonymous_to_peers = serializers.BooleanField(default=False)
last_edit = serializers.SerializerMethodField(required=False)
- edit_reason_code = serializers.CharField(required=False, validators=[validate_edit_reason_code])
+ edit_reason_code = serializers.CharField(
+ required=False, validators=[validate_edit_reason_code]
+ )
edit_by_label = serializers.SerializerMethodField(required=False)
+ is_deleted = serializers.SerializerMethodField(read_only=True)
+ deleted_at = serializers.SerializerMethodField(read_only=True)
+ deleted_by = serializers.SerializerMethodField(read_only=True)
+ deleted_by_label = serializers.SerializerMethodField(read_only=True)
+ is_author_banned = serializers.SerializerMethodField(read_only=True)
+ author_ban_scope = serializers.SerializerMethodField(read_only=True)
non_updatable_fields = set()
@@ -217,7 +244,10 @@ def _is_user_privileged(self, user_id):
Returns a boolean indicating whether the given user_id identifies a
privileged user.
"""
- return user_id in self.context["moderator_user_ids"] or user_id in self.context["ta_user_ids"]
+ return (
+ user_id in self.context["moderator_user_ids"]
+ or user_id in self.context["ta_user_ids"]
+ )
def _is_anonymous(self, obj):
"""
@@ -225,13 +255,13 @@ def _is_anonymous(self, obj):
the requester.
"""
user_id = self.context["request"].user.id
- is_user_staff = user_id in self.context["moderator_user_ids"] or user_id in self.context["ta_user_ids"]
-
- return (
- obj["anonymous"] or
- obj["anonymous_to_peers"] and not is_user_staff
+ is_user_staff = (
+ user_id in self.context["moderator_user_ids"]
+ or user_id in self.context["ta_user_ids"]
)
+ return obj["anonymous"] or obj["anonymous_to_peers"] and not is_user_staff
+
def get_author(self, obj):
"""
Returns the author's username, or None if the content is anonymous.
@@ -248,10 +278,9 @@ def _get_user_label(self, user_id):
is_ta = user_id in self.context["ta_user_ids"]
return (
- "Staff" if is_staff else
- "Moderator" if is_moderator else
- "Community TA" if is_ta else
- None
+ "Staff"
+ if is_staff
+ else "Moderator" if is_moderator else "Community TA" if is_ta else None
)
def _get_user_label_from_username(self, username):
@@ -275,13 +304,35 @@ def get_author_label(self, obj):
user_id = int(obj["user_id"])
return self._get_user_label(user_id)
+ def get_learner_status(self, obj):
+ """
+ Get the learner status for the discussion post author.
+ Returns one of: "anonymous", "staff", "new", "regular"
+ """
+ # Skip for anonymous content
+ if self._is_anonymous(obj) or obj.get("user_id") is None:
+ return "anonymous"
+
+ try:
+ user = User.objects.get(id=int(obj["user_id"]))
+ except (User.DoesNotExist, ValueError):
+ return "anonymous"
+
+ course = self.context.get("course")
+ if not course:
+ return "anonymous"
+
+ return get_user_learner_status(user, course.id)
+
def get_rendered_body(self, obj):
"""
Returns the rendered body content.
"""
if self._rendered_body is None:
self._rendered_body = render_body(obj["body"])
- self._rendered_body, is_spam = filter_spam_urls_from_html(self._rendered_body)
+ self._rendered_body, is_spam = filter_spam_urls_from_html(
+ self._rendered_body
+ )
if is_spam and settings.CONTENT_FOR_SPAM_POSTS:
self._rendered_body = settings.CONTENT_FOR_SPAM_POSTS
return self._rendered_body
@@ -293,8 +344,9 @@ def get_abuse_flagged(self, obj):
"""
total_abuse_flaggers = len(obj.get("abuse_flaggers", []))
return (
- self.context["has_moderation_privilege"] and total_abuse_flaggers > 0 or
- self.context["cc_requester"]["id"] in obj.get("abuse_flaggers", [])
+ self.context["has_moderation_privilege"]
+ and total_abuse_flaggers > 0
+ or self.context["cc_requester"]["id"] in obj.get("abuse_flaggers", [])
)
def get_voted(self, obj):
@@ -327,7 +379,7 @@ def get_last_edit(self, obj):
Returns information about the last edit for this content for
privileged users.
"""
- is_user_author = str(obj['user_id']) == str(self.context['request'].user.id)
+ is_user_author = str(obj["user_id"]) == str(self.context["request"].user.id)
if not (_validate_privileged_access(self.context) or is_user_author):
return None
edit_history = obj.get("edit_history")
@@ -343,12 +395,233 @@ def get_edit_by_label(self, obj):
"""
Returns the role label for the last edit user.
"""
- is_user_author = str(obj['user_id']) == str(self.context['request'].user.id)
+ is_user_author = str(obj["user_id"]) == str(self.context["request"].user.id)
is_user_privileged = _validate_privileged_access(self.context)
edit_history = obj.get("edit_history")
if (is_user_author or is_user_privileged) and edit_history:
last_edit = edit_history[-1]
- return self._get_user_label_from_username(last_edit.get('editor_username'))
+ return self._get_user_label_from_username(last_edit.get("editor_username"))
+
+ def get_is_deleted(self, obj):
+ """
+ Returns the is_deleted status for privileged users or content authors.
+ """
+ is_user_author = str(obj["user_id"]) == str(self.context["request"].user.id)
+ if not (_validate_privileged_access(self.context) or is_user_author):
+ return None
+ return obj.get("is_deleted", False)
+
+ def get_deleted_at(self, obj):
+ """
+ Returns the deletion timestamp for privileged users or content authors.
+ """
+ is_user_author = str(obj["user_id"]) == str(self.context["request"].user.id)
+ if not (_validate_privileged_access(self.context) or is_user_author):
+ return None
+ return obj.get("deleted_at")
+
+ def get_deleted_by(self, obj):
+ """
+ Returns the username of the user who deleted this content for privileged users or content authors.
+ """
+ is_user_author = str(obj["user_id"]) == str(self.context["request"].user.id)
+ if not (_validate_privileged_access(self.context) or is_user_author):
+ return None
+ deleted_by_id = obj.get("deleted_by")
+ if deleted_by_id:
+ try:
+ user = User.objects.get(id=int(deleted_by_id))
+ return user.username
+ except (User.DoesNotExist, ValueError):
+ return None
+ return None
+
+ def get_deleted_by_label(self, obj):
+ """
+ Returns the role label for the user who deleted this content for privileged users only.
+ """
+ if not _validate_privileged_access(self.context):
+ return None
+ deleted_by_id = obj.get("deleted_by")
+ if deleted_by_id:
+ try:
+ return self._get_user_label(int(deleted_by_id))
+ except (ValueError, TypeError):
+ return None
+ return None
+
+ def _get_author_ban_cache_key(self, course_id, user_id):
+ """Build a stable cache key for author ban lookups."""
+ return (str(course_id), int(user_id))
+
+ def _get_author_from_cache(self, user_id):
+ """Fetch author from per-request cache or database."""
+ user_cache = self.context.setdefault("_author_ban_user_cache", {})
+ if user_id not in user_cache:
+ try:
+ user_cache[user_id] = User.objects.get(id=user_id)
+ except User.DoesNotExist:
+ user_cache[user_id] = None
+ return user_cache[user_id]
+
+ def get_is_author_banned(self, obj):
+ """
+ Returns True if the content author is banned from discussions.
+ Returns False for anonymous content or if ban check fails.
+ """
+ from forum import api as forum_api
+ from lms.djangoapps.discussion.toggles import ENABLE_DISCUSSION_BAN
+
+ # Skip for anonymous content
+ if self._is_anonymous(obj) or obj.get("user_id") is None:
+ return False
+
+ # Skip if ban function not available
+ is_user_banned_func = getattr(forum_api, 'is_user_banned', None)
+ if not is_user_banned_func:
+ return False
+
+ # Skip if feature flag is not enabled
+ course_id = self.context.get("course_id")
+ if not course_id or not ENABLE_DISCUSSION_BAN.is_enabled(course_id):
+ return False
+
+ try:
+ user_id = int(obj["user_id"])
+ except (ValueError, TypeError):
+ return False
+
+ cache_key = self._get_author_ban_cache_key(course_id, user_id)
+ ban_status_cache = self.context.setdefault("_author_ban_status_cache", {})
+ if cache_key in ban_status_cache:
+ return ban_status_cache[cache_key]
+
+ try:
+ user = self._get_author_from_cache(user_id)
+ if not user:
+ ban_status_cache[cache_key] = False
+ return False
+
+ is_banned = is_user_banned_func(user, course_id)
+ ban_status_cache[cache_key] = is_banned
+ return is_banned
+ except (User.DoesNotExist, ValueError, Exception): # pylint: disable=broad-except
+ ban_status_cache[cache_key] = False
+
+ return False
+
+ def get_author_ban_scope(self, obj):
+ """
+ Returns the scope of the author's ban ('course' or 'organization').
+ Returns None for anonymous content, unbanned users, or if check fails.
+ """
+ from forum import api as forum_api
+ from lms.djangoapps.discussion.toggles import ENABLE_DISCUSSION_BAN
+ import logging
+ logger = logging.getLogger(__name__)
+
+ # Skip for anonymous content
+ if self._is_anonymous(obj) or obj.get("user_id") is None:
+ return None
+
+ # Skip if required functions not available
+ is_user_banned_func = getattr(forum_api, 'is_user_banned', None)
+ get_user_bans_func = getattr(forum_api, 'get_user_bans', None)
+ if not is_user_banned_func:
+ return None
+
+ # Skip if feature flag is not enabled
+ course_id = self.context.get("course_id")
+ if not course_id or not ENABLE_DISCUSSION_BAN.is_enabled(course_id):
+ return None
+
+ try:
+ user_id = int(obj["user_id"])
+ except (ValueError, TypeError):
+ return None
+
+ cache_key = self._get_author_ban_cache_key(course_id, user_id)
+ ban_scope_cache = self.context.setdefault("_author_ban_scope_cache", {})
+ if cache_key in ban_scope_cache:
+ return ban_scope_cache[cache_key]
+
+ ban_status_cache = self.context.setdefault("_author_ban_status_cache", {})
+
+ try:
+ user = self._get_author_from_cache(user_id)
+ if not user:
+ ban_scope_cache[cache_key] = None
+ return None
+
+ if not course_id:
+ ban_scope_cache[cache_key] = None
+ return None
+
+ # First check if user is banned at all
+ user_banned = ban_status_cache.get(cache_key)
+ if user_banned is None:
+ user_banned = is_user_banned_func(user, course_id)
+ ban_status_cache[cache_key] = user_banned
+
+ if not user_banned:
+ ban_scope_cache[cache_key] = None
+ return None
+
+ # Try to get all active bans for this user and course
+ if get_user_bans_func:
+ try:
+ bans = get_user_bans_func(user=user, course_id=course_id)
+ # Check for organization-level ban first (higher precedence)
+ for ban in bans:
+ if ban.get('is_active') and ban.get('scope') == 'organization':
+ ban_scope_cache[cache_key] = 'organization'
+ return 'organization'
+ # Then check for course-level ban
+ for ban in bans:
+ if ban.get('is_active') and ban.get('scope') == 'course':
+ ban_scope_cache[cache_key] = 'course'
+ return 'course'
+ except Exception as e: # pylint: disable=broad-except
+ logger.debug(
+ "Unable to fetch ban list for ban-scope detection. course_id=%s user_id=%s error=%s",
+ course_id,
+ obj.get("user_id"),
+ e,
+ )
+
+ # Fallback: Try checking each scope individually using is_user_banned
+ # check_org parameter: True = include org checks, False = course-only
+ try:
+ # Check course-only (check_org=False means don't check org)
+ course_only = is_user_banned_func(user, course_id, check_org=False)
+
+ # If course-only check returns False but user IS banned, must be org-banned
+ if not course_only:
+ ban_scope_cache[cache_key] = 'organization'
+ return 'organization'
+
+ # If course-only check returns True, it's course-level ban
+ ban_scope_cache[cache_key] = 'course'
+ return 'course'
+ except TypeError as e:
+ # check_org parameter might not exist in older versions
+ logger.debug(
+ "check_org parameter unsupported during ban-scope detection. course_id=%s user_id=%s error=%s",
+ course_id,
+ obj.get("user_id"),
+ e,
+ )
+
+ except (User.DoesNotExist, ValueError, Exception) as e: # pylint: disable=broad-except
+ logger.warning(
+ "Unable to determine author ban scope. course_id=%s user_id=%s error=%s",
+ self.context.get("course_id"),
+ obj.get("user_id"),
+ e,
+ )
+
+ ban_scope_cache[cache_key] = None
+ return None
class ThreadSerializer(_ContentSerializer):
@@ -359,13 +632,15 @@ class ThreadSerializer(_ContentSerializer):
not had retrieve() called, because of the interaction between DRF's attempts
at introspection and Thread's __getattr__.
"""
+
course_id = serializers.CharField()
- topic_id = serializers.CharField(source="commentable_id", validators=[validate_not_blank])
+ topic_id = serializers.CharField(
+ source="commentable_id", validators=[validate_not_blank]
+ )
group_id = serializers.IntegerField(required=False, allow_null=True)
group_name = serializers.SerializerMethodField()
type = serializers.ChoiceField(
- source="thread_type",
- choices=[(val, val) for val in ["discussion", "question"]]
+ source="thread_type", choices=[(val, val) for val in ["discussion", "question"]]
)
preview_body = serializers.SerializerMethodField()
abuse_flagged_count = serializers.SerializerMethodField(required=False)
@@ -380,8 +655,12 @@ class ThreadSerializer(_ContentSerializer):
non_endorsed_comment_list_url = serializers.SerializerMethodField()
read = serializers.BooleanField(required=False)
has_endorsed = serializers.BooleanField(source="endorsed", read_only=True)
- response_count = serializers.IntegerField(source="resp_total", read_only=True, required=False)
- close_reason_code = serializers.CharField(required=False, validators=[validate_close_reason_code])
+ response_count = serializers.IntegerField(
+ source="resp_total", read_only=True, required=False
+ )
+ close_reason_code = serializers.CharField(
+ required=False, validators=[validate_close_reason_code]
+ )
close_reason = serializers.SerializerMethodField()
closed_by = serializers.SerializerMethodField()
closed_by_label = serializers.SerializerMethodField(required=False)
@@ -427,9 +706,8 @@ def get_comment_list_url(self, obj, endorsed=None):
Returns the URL to retrieve the thread's comments, optionally including
the endorsed query parameter.
"""
- if (
- (obj["thread_type"] == "question" and endorsed is None) or
- (obj["thread_type"] == "discussion" and endorsed is not None)
+ if (obj["thread_type"] == "question" and endorsed is None) or (
+ obj["thread_type"] == "discussion" and endorsed is not None
):
return None
path = reverse("comment-list")
@@ -473,13 +751,17 @@ def get_preview_body(self, obj):
"""
Returns a cleaned version of the thread's body to display in a preview capacity.
"""
- return strip_tags(self.get_rendered_body(obj)).replace('\n', ' ').replace(' ', ' ')
+ return (
+ strip_tags(self.get_rendered_body(obj))
+ .replace("\n", " ")
+ .replace(" ", " ")
+ )
def get_close_reason(self, obj):
"""
Returns the reason for which the thread was closed.
"""
- is_user_author = str(obj['user_id']) == str(self.context['request'].user.id)
+ is_user_author = str(obj["user_id"]) == str(self.context["request"].user.id)
if not (_validate_privileged_access(self.context) or is_user_author):
return None
reason_code = obj.get("close_reason_code")
@@ -490,7 +772,7 @@ def get_closed_by(self, obj):
Returns the username of the moderator who closed this thread,
only to other privileged users and author.
"""
- is_user_author = str(obj['user_id']) == str(self.context['request'].user.id)
+ is_user_author = str(obj["user_id"]) == str(self.context["request"].user.id)
if _validate_privileged_access(self.context) or is_user_author:
return obj.get("closed_by")
@@ -498,7 +780,7 @@ def get_closed_by_label(self, obj):
"""
Returns the role label for the user who closed the post.
"""
- is_user_author = str(obj['user_id']) == str(self.context['request'].user.id)
+ is_user_author = str(obj["user_id"]) == str(self.context["request"].user.id)
if is_user_author or _validate_privileged_access(self.context):
return self._get_user_label_from_username(obj.get("closed_by"))
@@ -513,18 +795,31 @@ def update(self, instance, validated_data):
requesting_user_id = self.context["cc_requester"]["id"]
if key == "closed" and val:
instance["closing_user_id"] = requesting_user_id
- track_thread_lock_unlock_event(self.context['request'], self.context['course'],
- instance, validated_data.get('close_reason_code'))
+ track_thread_lock_unlock_event(
+ self.context["request"],
+ self.context["course"],
+ instance,
+ validated_data.get("close_reason_code"),
+ )
if key == "closed" and not val:
instance["closing_user_id"] = requesting_user_id
- track_thread_lock_unlock_event(self.context['request'], self.context['course'],
- instance, validated_data.get('close_reason_code'), locked=False)
+ track_thread_lock_unlock_event(
+ self.context["request"],
+ self.context["course"],
+ instance,
+ validated_data.get("close_reason_code"),
+ locked=False,
+ )
if key == "body" and val:
instance["editing_user_id"] = requesting_user_id
- track_thread_edited_event(self.context['request'], self.context['course'],
- instance, validated_data.get('edit_reason_code'))
+ track_thread_edited_event(
+ self.context["request"],
+ self.context["course"],
+ instance,
+ validated_data.get("edit_reason_code"),
+ )
instance.save()
return instance
@@ -537,6 +832,7 @@ class CommentSerializer(_ContentSerializer):
not had retrieve() called, because of the interaction between DRF's attempts
at introspection and Comment's __getattr__.
"""
+
thread_id = serializers.CharField()
parent_id = serializers.CharField(required=False, allow_null=True)
endorsed = serializers.BooleanField(required=False)
@@ -551,7 +847,7 @@ class CommentSerializer(_ContentSerializer):
non_updatable_fields = NON_UPDATABLE_COMMENT_FIELDS
def __init__(self, *args, **kwargs):
- remove_fields = kwargs.pop('remove_fields', None)
+ remove_fields = kwargs.pop("remove_fields", None)
super().__init__(*args, **kwargs)
if remove_fields:
@@ -573,8 +869,8 @@ def get_endorsed_by(self, obj):
# Avoid revealing the identity of an anonymous non-staff question
# author who has endorsed a comment in the thread
if not (
- self._is_anonymous(self.context["thread"]) and
- not self._is_user_privileged(endorser_id)
+ self._is_anonymous(self.context["thread"])
+ and not self._is_user_privileged(endorser_id)
):
return User.objects.get(id=endorser_id).username
return None
@@ -616,7 +912,7 @@ def to_representation(self, data):
# Django Rest Framework v3 no longer includes None values
# in the representation. To maintain the previous behavior,
# we do this manually instead.
- if 'parent_id' not in data:
+ if "parent_id" not in data:
data["parent_id"] = None
return data
@@ -658,7 +954,7 @@ def create(self, validated_data):
comment = Comment(
course_id=self.context["thread"]["course_id"],
user_id=self.context["cc_requester"]["id"],
- **validated_data
+ **validated_data,
)
comment.save()
return comment
@@ -671,12 +967,18 @@ def update(self, instance, validated_data):
# endorsement_user_id on update
requesting_user_id = self.context["cc_requester"]["id"]
if key == "endorsed":
- track_forum_response_mark_event(self.context['request'], self.context['course'], instance, val)
+ track_forum_response_mark_event(
+ self.context["request"], self.context["course"], instance, val
+ )
instance["endorsement_user_id"] = requesting_user_id
if key == "body" and val:
instance["editing_user_id"] = requesting_user_id
- track_comment_edited_event(self.context['request'], self.context['course'],
- instance, validated_data.get('edit_reason_code'))
+ track_comment_edited_event(
+ self.context["request"],
+ self.context["course"],
+ instance,
+ validated_data.get("edit_reason_code"),
+ )
instance.save()
return instance
@@ -686,6 +988,7 @@ class DiscussionTopicSerializer(serializers.Serializer):
"""
Serializer for DiscussionTopic
"""
+
id = serializers.CharField(read_only=True) # pylint: disable=invalid-name
name = serializers.CharField(read_only=True)
thread_list_url = serializers.CharField(read_only=True)
@@ -715,10 +1018,11 @@ class DiscussionTopicSerializerV2(serializers.Serializer):
"""
Serializer for new style topics.
"""
+
id = serializers.CharField( # pylint: disable=invalid-name
read_only=True,
source="external_id",
- help_text="Provider-specific unique id for the topic"
+ help_text="Provider-specific unique id for the topic",
)
usage_key = serializers.CharField(
read_only=True,
@@ -742,10 +1046,13 @@ def get_thread_counts(self, obj: DiscussionTopicLink) -> Dict[str, int]:
"""
Get thread counts from provided context
"""
- return self.context['thread_counts'].get(obj.external_id, {
- "discussion": 0,
- "question": 0,
- })
+ return self.context["thread_counts"].get(
+ obj.external_id,
+ {
+ "discussion": 0,
+ "question": 0,
+ },
+ )
class DiscussionRolesSerializer(serializers.Serializer):
@@ -753,10 +1060,7 @@ class DiscussionRolesSerializer(serializers.Serializer):
Serializer for course discussion roles.
"""
- ACTION_CHOICES = (
- ('allow', 'allow'),
- ('revoke', 'revoke')
- )
+ ACTION_CHOICES = (("allow", "allow"), ("revoke", "revoke"))
action = serializers.ChoiceField(ACTION_CHOICES)
user_id = serializers.CharField()
@@ -777,14 +1081,16 @@ def validate_user_id(self, user_id):
self.user = get_user_by_username_or_email(user_id)
return user_id
except User.DoesNotExist as err:
- raise ValidationError(f"'{user_id}' is not a valid student identifier") from err
+ raise ValidationError(
+ f"'{user_id}' is not a valid student identifier"
+ ) from err
def validate(self, attrs):
"""Validate the data at an object level."""
# Store the user object to avoid fetching it again.
- if hasattr(self, 'user'):
- attrs['user'] = self.user
+ if hasattr(self, "user"):
+ attrs["user"] = self.user
return attrs
def create(self, validated_data):
@@ -802,6 +1108,7 @@ class DiscussionRolesMemberSerializer(serializers.Serializer):
"""
Serializer for course discussion roles member data.
"""
+
username = serializers.CharField()
email = serializers.EmailField()
first_name = serializers.CharField()
@@ -810,7 +1117,7 @@ class DiscussionRolesMemberSerializer(serializers.Serializer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.course_discussion_settings = self.context['course_discussion_settings']
+ self.course_discussion_settings = self.context["course_discussion_settings"]
def get_group_name(self, instance):
"""Return the group name of the user."""
@@ -833,6 +1140,7 @@ class DiscussionRolesListSerializer(serializers.Serializer):
"""
Serializer for course discussion roles member list.
"""
+
course_id = serializers.CharField()
results = serializers.SerializerMethodField()
division_scheme = serializers.SerializerMethodField()
@@ -840,15 +1148,17 @@ class DiscussionRolesListSerializer(serializers.Serializer):
def get_results(self, obj):
"""Return the nested serializer data representing a list of member users."""
context = {
- 'course_id': obj['course_id'],
- 'course_discussion_settings': self.context['course_discussion_settings']
+ "course_id": obj["course_id"],
+ "course_discussion_settings": self.context["course_discussion_settings"],
}
- serializer = DiscussionRolesMemberSerializer(obj['users'], context=context, many=True)
+ serializer = DiscussionRolesMemberSerializer(
+ obj["users"], context=context, many=True
+ )
return serializer.data
def get_division_scheme(self, obj): # pylint: disable=unused-argument
"""Return the division scheme for the course."""
- return self.context['course_discussion_settings'].division_scheme
+ return self.context["course_discussion_settings"].division_scheme
def create(self, validated_data):
"""
@@ -865,9 +1175,13 @@ class UserStatsSerializer(serializers.Serializer):
"""
Serializer for course user stats.
"""
+
threads = serializers.IntegerField()
replies = serializers.IntegerField()
responses = serializers.IntegerField()
+ deleted_threads = serializers.IntegerField(required=False, default=0)
+ deleted_replies = serializers.IntegerField(required=False, default=0)
+ deleted_responses = serializers.IntegerField(required=False, default=0)
active_flags = serializers.IntegerField()
inactive_flags = serializers.IntegerField()
username = serializers.CharField()
@@ -885,27 +1199,36 @@ class BlackoutDateSerializer(serializers.Serializer):
"""
Serializer for blackout dates.
"""
- start = serializers.DateTimeField(help_text="The ISO 8601 timestamp for the start of the blackout period")
- end = serializers.DateTimeField(help_text="The ISO 8601 timestamp for the end of the blackout period")
+
+ start = serializers.DateTimeField(
+ help_text="The ISO 8601 timestamp for the start of the blackout period"
+ )
+ end = serializers.DateTimeField(
+ help_text="The ISO 8601 timestamp for the end of the blackout period"
+ )
class ReasonCodeSeralizer(serializers.Serializer):
"""
Serializer for reason codes.
"""
+
code = serializers.CharField(help_text="A code for the an edit or close reason")
- label = serializers.CharField(help_text="A user-friendly name text for the close or edit reason")
+ label = serializers.CharField(
+ help_text="A user-friendly name text for the close or edit reason"
+ )
class CourseMetadataSerailizer(serializers.Serializer):
"""
Serializer for course metadata.
"""
+
id = CourseKeyField(help_text="The identifier of the course")
blackouts = serializers.ListField(
child=BlackoutDateSerializer(),
help_text="A list of objects representing blackout periods "
- "(during which discussions are read-only except for privileged users)."
+ "(during which discussions are read-only except for privileged users).",
)
thread_list_url = serializers.URLField(
help_text="The URL of the list of all threads in the course.",
@@ -913,7 +1236,9 @@ class CourseMetadataSerailizer(serializers.Serializer):
following_thread_list_url = serializers.URLField(
help_text="thread_list_url with parameter following=True",
)
- topics_url = serializers.URLField(help_text="The URL of the topic listing for the course.")
+ topics_url = serializers.URLField(
+ help_text="The URL of the topic listing for the course."
+ )
allow_anonymous = serializers.BooleanField(
help_text="A boolean indicating whether anonymous posts are allowed or not.",
)
@@ -944,3 +1269,141 @@ class CourseMetadataSerailizer(serializers.Serializer):
child=ReasonCodeSeralizer(),
help_text="A list of reasons that can be specified by moderators for editing a post, response, or comment",
)
+
+
+class BulkDeleteBanRequestSerializer(serializers.Serializer):
+ """
+ Request payload for bulk delete + ban action.
+
+ Accepts either user_id (for programmatic access) or username (for UI/human convenience).
+ Internally normalizes to user_id before processing.
+ """
+
+ user_id = serializers.IntegerField(
+ required=False,
+ help_text="User ID to ban. Either user_id or username must be provided."
+ )
+ username = serializers.CharField(
+ required=False,
+ max_length=150,
+ help_text="Username to ban. Converted to user_id internally. Either user_id or username must be provided."
+ )
+ course_id = serializers.CharField(max_length=255, required=True)
+ ban_user = serializers.BooleanField(default=False)
+ ban_scope = serializers.ChoiceField(
+ choices=['course', 'organization'],
+ default='course',
+ help_text="Scope of the ban: 'course' for course-level or 'organization' for organization-level"
+ )
+ reason = serializers.CharField(
+ required=False,
+ allow_blank=True,
+ max_length=1000
+ )
+
+ def validate(self, data):
+ """
+ Validate and normalize user identification.
+
+ - Ensures either user_id or username is provided
+ - Converts username to user_id if needed
+ - Validates ban requirements (reason, permissions)
+ """
+ # Validate that either user_id or username is provided
+ if not data.get('user_id') and not data.get('username'):
+ raise serializers.ValidationError({
+ 'user_id': "Either user_id or username must be provided."
+ })
+
+ # Normalize username to user_id for internal processing
+ # This allows the view/task to always work with user_id
+ if data.get('username') and not data.get('user_id'):
+ try:
+ user = User.objects.get(username=data['username'])
+ data['user_id'] = user.id
+ # Keep username for logging/audit purposes
+ data['resolved_username'] = user.username
+ except User.DoesNotExist as exc:
+ raise serializers.ValidationError({
+ 'username': f"User with username '{data['username']}' does not exist."
+ }) from exc
+ elif data.get('user_id'):
+ # If user_id provided directly, resolve username for consistency
+ try:
+ user = User.objects.get(id=data['user_id'])
+ data['resolved_username'] = user.username
+ except User.DoesNotExist as exc:
+ raise serializers.ValidationError({
+ 'user_id': f"User with ID {data['user_id']} does not exist."
+ }) from exc
+
+ if data.get('ban_user'):
+ reason = data.get('reason', '').strip()
+ if not reason:
+ raise serializers.ValidationError({
+ 'reason': "Reason is required when banning a user."
+ })
+
+ # Validate that organization-level bans require elevated permissions
+ # only when a ban is requested.
+ if data.get('ban_user') and data.get('ban_scope') == 'organization':
+ request = self.context.get('request')
+ if request and not (
+ GlobalStaff().has_user(request.user) or request.user.is_staff
+ ):
+ raise serializers.ValidationError({
+ 'ban_scope': "Organization-level bans require global staff permissions."
+ })
+
+ return data
+
+
+class BanUserRequestSerializer(serializers.Serializer):
+ """
+ Request payload for standalone ban action (without bulk delete).
+
+ For direct ban from UI moderation actions.
+ """
+
+ user_id = serializers.IntegerField(
+ required=False,
+ help_text="User ID to ban. Either user_id or username must be provided."
+ )
+ username = serializers.CharField(
+ required=False,
+ max_length=150,
+ help_text="Username to ban. Converted to user_id internally. Either user_id or username must be provided."
+ )
+ course_id = serializers.CharField(
+ max_length=255,
+ required=True,
+ help_text="Course ID for course-level bans or org context for organization-level bans"
+ )
+ scope = serializers.ChoiceField(
+ choices=['course', 'organization'],
+ default='course',
+ help_text="Scope of the ban: 'course' for course-level or 'organization' for organization-level"
+ )
+ reason = serializers.CharField(
+ required=False,
+ allow_blank=True,
+ max_length=1000,
+ help_text="Reason for the ban (optional)"
+ )
+
+ def validate(self, data):
+ """
+ Validate and normalize user identification.
+ """
+ # Validate that either user_id or username is provided
+ if not data.get('user_id') and not data.get('username'):
+ raise serializers.ValidationError({
+ 'user_id': "Either user_id or username must be provided."
+ })
+
+ # Normalize username to user_id if provided (view will validate existence)
+ if data.get('username') and not data.get('user_id'):
+ # Don't validate user existence here - let the view return 404
+ # Just record the username for the view to resolve
+ data['lookup_username'] = data['username']
+ return data
diff --git a/lms/djangoapps/discussion/rest_api/tasks.py b/lms/djangoapps/discussion/rest_api/tasks.py
index cd725a3513dc..3bc96b654288 100644
--- a/lms/djangoapps/discussion/rest_api/tasks.py
+++ b/lms/djangoapps/discussion/rest_api/tasks.py
@@ -1,32 +1,36 @@
"""
Contain celery tasks
"""
+
import logging
from celery import shared_task
from django.contrib.auth import get_user_model
from edx_django_utils.monitoring import set_code_owner_attribute
-from opaque_keys.edx.locator import CourseKey
from eventtracking import tracker
+from opaque_keys.edx.keys import CourseKey
-from common.djangoapps.student.roles import CourseStaffRole, CourseInstructorRole
+from common.djangoapps.student.roles import CourseInstructorRole, CourseStaffRole
from common.djangoapps.track import segment
from lms.djangoapps.courseware.courses import get_course_with_access
from lms.djangoapps.discussion.django_comment_client.utils import get_user_role_names
-from lms.djangoapps.discussion.rest_api.discussions_notifications import DiscussionNotificationSender
+from lms.djangoapps.discussion.rest_api.discussions_notifications import (
+ DiscussionNotificationSender,
+)
from lms.djangoapps.discussion.rest_api.utils import can_user_notify_all_learners
from openedx.core.djangoapps.django_comment_common.comment_client import Comment
from openedx.core.djangoapps.django_comment_common.comment_client.thread import Thread
from openedx.core.djangoapps.notifications.config.waffle import ENABLE_NOTIFICATIONS
-
User = get_user_model()
log = logging.getLogger(__name__)
@shared_task
@set_code_owner_attribute
-def send_thread_created_notification(thread_id, course_key_str, user_id, notify_all_learners=False):
+def send_thread_created_notification(
+ thread_id, course_key_str, user_id, notify_all_learners=False
+):
"""
Send notification when a new thread is created
"""
@@ -40,17 +44,21 @@ def send_thread_created_notification(thread_id, course_key_str, user_id, notify_
is_course_staff = CourseStaffRole(course_key).has_user(user)
is_course_admin = CourseInstructorRole(course_key).has_user(user)
user_roles = get_user_role_names(user, course_key)
- if not can_user_notify_all_learners(user_roles, is_course_staff, is_course_admin):
+ if not can_user_notify_all_learners(
+ user_roles, is_course_staff, is_course_admin
+ ):
return
- course = get_course_with_access(user, 'load', course_key, check_if_enrolled=True)
+ course = get_course_with_access(user, "load", course_key, check_if_enrolled=True)
notification_sender = DiscussionNotificationSender(thread, course, user)
notification_sender.send_new_thread_created_notification(notify_all_learners)
@shared_task
@set_code_owner_attribute
-def send_response_notifications(thread_id, course_key_str, user_id, comment_id, parent_id=None):
+def send_response_notifications(
+ thread_id, course_key_str, user_id, comment_id, parent_id=None
+):
"""
Send notifications to users who are subscribed to the thread.
"""
@@ -59,8 +67,10 @@ def send_response_notifications(thread_id, course_key_str, user_id, comment_id,
return
thread = Thread(id=thread_id).retrieve()
user = User.objects.get(id=user_id)
- course = get_course_with_access(user, 'load', course_key, check_if_enrolled=True)
- notification_sender = DiscussionNotificationSender(thread, course, user, parent_id, comment_id)
+ course = get_course_with_access(user, "load", course_key, check_if_enrolled=True)
+ notification_sender = DiscussionNotificationSender(
+ thread, course, user, parent_id, comment_id
+ )
notification_sender.send_new_comment_notification()
notification_sender.send_new_response_notification()
notification_sender.send_new_comment_on_response_notification()
@@ -69,7 +79,9 @@ def send_response_notifications(thread_id, course_key_str, user_id, comment_id,
@shared_task
@set_code_owner_attribute
-def send_response_endorsed_notifications(thread_id, response_id, course_key_str, endorsed_by):
+def send_response_endorsed_notifications(
+ thread_id, response_id, course_key_str, endorsed_by
+):
"""
Send notifications when a response is marked answered/ endorsed
"""
@@ -80,8 +92,10 @@ def send_response_endorsed_notifications(thread_id, response_id, course_key_str,
response = Comment(id=response_id).retrieve()
creator = User.objects.get(id=response.user_id)
endorser = User.objects.get(id=endorsed_by)
- course = get_course_with_access(creator, 'load', course_key, check_if_enrolled=True)
- notification_sender = DiscussionNotificationSender(thread, course, creator, comment_id=response_id)
+ course = get_course_with_access(creator, "load", course_key, check_if_enrolled=True)
+ notification_sender = DiscussionNotificationSender(
+ thread, course, creator, comment_id=response_id
+ )
# skip sending notification to author of thread if they are the same as the author of the response
if response.user_id != thread.user_id:
# sends notification to author of thread
@@ -92,22 +106,177 @@ def send_response_endorsed_notifications(thread_id, response_id, course_key_str,
notification_sender.send_response_endorsed_notification()
+@shared_task(
+ bind=True, # Enable retry context and access to task instance
+ max_retries=3, # Retry up to 3 times on failure
+ default_retry_delay=60, # Wait 60 seconds between retries
+ autoretry_for=(OSError, TimeoutError), # Only retry on transient network/IO errors
+ retry_backoff=True, # Exponential backoff between retries
+ retry_jitter=True, # Add randomization to retry delays
+)
+@set_code_owner_attribute
+def delete_course_post_for_user( # pylint: disable=too-many-statements
+ self,
+ user_id,
+ username=None,
+ course_ids=None,
+ event_data=None,
+ # NEW PARAMETERS (backward compatible - all have defaults):
+ ban_user=False,
+ ban_scope='course',
+ moderator_id=None,
+ reason=None,
+):
+ """
+ Delete all discussion posts for a user and optionally ban them.
+
+ BACKWARD COMPATIBLE: Existing callers without ban_user parameter
+ will experience no change in behavior.
+
+ Args:
+ self: Task instance (when bind=True)
+ user_id: User whose posts to delete
+ username: Username of the user (optional, will be fetched if not provided)
+ course_ids: List of course IDs (API sends single course wrapped in array)
+ event_data: Event tracking metadata
+ ban_user: If True, create ban record (NEW)
+ ban_scope: 'course' or 'organization' (NEW)
+ moderator_id: Moderator applying ban (NEW)
+ reason: Ban reason (NEW)
+ """
+ event_data = event_data or {}
+ log.info(
+ f"<> Deleting all posts for {username} in course {course_ids}"
+ )
+ # Get triggered_by user_id from event_data for audit trail
+ deleted_by_user_id = event_data.get("triggered_by_user_id") if event_data else None
+ threads_deleted = Thread.delete_user_threads(
+ user_id, course_ids, deleted_by=deleted_by_user_id
+ )
+ comments_deleted = Comment.delete_user_comments(
+ user_id, course_ids, deleted_by=deleted_by_user_id
+ )
+ log.info(
+ f"<> Deleted {threads_deleted} posts and {comments_deleted} comments for {username} "
+ f"in course {course_ids}"
+ )
+
+ # Create ban record if requested
+ ban_id = None
+ ban_error = None
+ if ban_user:
+ try:
+ from forum import api as forum_api
+
+ # Get user objects
+ target_user = User.objects.get(id=user_id)
+ moderator = User.objects.get(id=moderator_id) if moderator_id else None
+
+ # Parse course key
+ course_key = CourseKey.from_string(course_ids[0]) if course_ids else None
+
+ # Create ban using forum API
+ ban_result = forum_api.ban_user(
+ user=target_user,
+ banned_by=moderator,
+ course_id=course_key,
+ scope=ban_scope,
+ reason=reason or "Bulk delete and ban operation"
+ )
+
+ ban_id = ban_result.get('id')
+
+ log.info(
+ f"<> Created {ban_scope}-level ban (ID: {ban_id}) "
+ f"for user {username} (ID: {user_id}) after deleting {threads_deleted + comments_deleted} items"
+ )
+
+ # Send escalation email (non-blocking)
+ try:
+ from lms.djangoapps.discussion.rest_api.emails import send_ban_escalation_email
+
+ send_ban_escalation_email(
+ banned_user_id=user_id,
+ moderator_id=moderator_id,
+ course_id=course_ids[0] if course_ids else None,
+ scope=ban_scope,
+ reason=reason,
+ threads_deleted=threads_deleted,
+ comments_deleted=comments_deleted,
+ )
+ except Exception as email_exc: # pylint: disable=broad-except
+ log.error(
+ "<> Failed to send ban escalation email for user %s (ID: %s): %s",
+ username,
+ user_id,
+ email_exc,
+ exc_info=True,
+ )
+
+ except Exception as e: # pylint: disable=broad-except
+ ban_error = str(e)
+ log.error(
+ f"<> Failed to create ban for user {username} (ID: {user_id}): {e}",
+ exc_info=True
+ )
+ # Don't fail the entire task if ban creation fails
+ # Discussions are already deleted, so we log the error and continue
+
+ event_data.update(
+ {
+ "number_of_posts_deleted": threads_deleted,
+ "number_of_comments_deleted": comments_deleted,
+ "ban_user": ban_user,
+ "ban_scope": ban_scope if ban_user else None,
+ "ban_id": ban_id if ban_user else None,
+ "ban_error": ban_error if ban_error else None,
+ }
+ )
+ event_name = "edx.discussion.bulk_delete_user_posts"
+ tracker.emit(event_name, event_data)
+ segment.track("None", event_name, event_data)
+
+ # Return task result for monitoring
+ return {
+ "threads_deleted": threads_deleted,
+ "comments_deleted": comments_deleted,
+ "ban_created": bool(ban_id),
+ "ban_id": ban_id,
+ "ban_error": ban_error,
+ }
+
+
@shared_task
@set_code_owner_attribute
-def delete_course_post_for_user(user_id, username, course_ids, event_data=None):
+def restore_course_post_for_user(user_id, username, course_ids, event_data=None):
"""
- Deletes all posts for user in a course.
+ Restores all soft-deleted posts for user in a course by setting is_deleted=False.
"""
event_data = event_data or {}
- log.info(f"<> Deleting all posts for {username} in course {course_ids}")
- threads_deleted = Thread.delete_user_threads(user_id, course_ids)
- comments_deleted = Comment.delete_user_comments(user_id, course_ids)
- log.info(f"<> Deleted {threads_deleted} posts and {comments_deleted} comments for {username} "
- f"in course {course_ids}")
- event_data.update({
- "number_of_posts_deleted": threads_deleted,
- "number_of_comments_deleted": comments_deleted,
- })
- event_name = 'edx.discussion.bulk_delete_user_posts'
+ log.info(
+ "<> Restoring all posts for %s in course %s", username, course_ids
+ )
+ # Get triggered_by user_id from event_data for audit trail
+ restored_by_user_id = event_data.get("triggered_by_user_id") if event_data else None
+ threads_restored = Thread.restore_user_deleted_threads(
+ user_id, course_ids, restored_by=restored_by_user_id
+ )
+ comments_restored = Comment.restore_user_deleted_comments(
+ user_id, course_ids, restored_by=restored_by_user_id
+ )
+ log.info(
+ "<> Restored %s posts and %s comments for %s in course %s",
+ threads_restored,
+ comments_restored,
+ username,
+ course_ids,
+ )
+ event_data.update(
+ {
+ "number_of_posts_restored": threads_restored,
+ "number_of_comments_restored": comments_restored,
+ }
+ )
+ event_name = "edx.discussion.bulk_restore_user_posts"
tracker.emit(event_name, event_data)
- segment.track('None', event_name, event_data)
+ segment.track("None", event_name, event_data)
diff --git a/lms/djangoapps/discussion/rest_api/tests/test_api.py b/lms/djangoapps/discussion/rest_api/tests/test_api.py
index d23a6a06b1b5..9018ad3945d2 100644
--- a/lms/djangoapps/discussion/rest_api/tests/test_api.py
+++ b/lms/djangoapps/discussion/rest_api/tests/test_api.py
@@ -9,6 +9,7 @@
import ddt
import httpretty
import pytest
+from django.core.exceptions import ValidationError
from django.test import override_settings
from django.contrib.auth import get_user_model
from django.test.client import RequestFactory
@@ -30,6 +31,8 @@
from common.djangoapps.util.testing import UrlResetMixin
from lms.djangoapps.discussion.django_comment_client.tests.utils import ForumsEnableMixin
from lms.djangoapps.discussion.rest_api.api import (
+ create_comment,
+ create_thread,
get_course,
get_course_topics,
get_user_comments,
@@ -37,6 +40,7 @@
from lms.djangoapps.discussion.rest_api.exceptions import (
DiscussionDisabledError,
)
+from rest_framework.exceptions import PermissionDenied
from lms.djangoapps.discussion.rest_api.tests.utils import (
CommentsServiceMockMixin,
make_minimal_cs_comment,
@@ -50,6 +54,10 @@
FORUM_ROLE_STUDENT,
Role
)
+from openedx.core.djangoapps.django_comment_common.comment_client.utils import (
+ CommentClient500Error,
+ CommentClientRequestError,
+)
from openedx.core.lib.exceptions import CourseNotFoundError, PageNotFoundError
User = get_user_model()
@@ -132,6 +140,7 @@ def test_basic(self):
assert get_course(self.request, self.course.id) == {
'id': str(self.course.id),
'is_posting_enabled': True,
+ 'is_user_banned': False,
'blackouts': [],
'thread_list_url': 'http://testserver/api/discussion/v1/threads/?course_id=course-v1%3Ax%2By%2Bz',
'following_thread_list_url':
@@ -159,7 +168,8 @@ def test_basic(self):
},
"is_email_verified": True,
"only_verified_users_can_post": False,
- "content_creation_rate_limited": False
+ "content_creation_rate_limited": False,
+ "enable_discussion_ban": False,
}
@ddt.data(
@@ -752,3 +762,117 @@ def test_call_with_non_existent_course(self):
course_key=CourseKey.from_string("course-v1:x+y+z"),
page=2,
)
+
+
+def test_create_thread_denies_banned_user():
+ request = RequestFactory().post('/dummy')
+ request.user = mock.Mock()
+
+ with mock.patch(
+ "lms.djangoapps.discussion.rest_api.api._get_course",
+ return_value=mock.Mock(),
+ ), mock.patch(
+ "lms.djangoapps.discussion.rest_api.api.get_context",
+ return_value={},
+ ), mock.patch(
+ "lms.djangoapps.discussion.rest_api.api.discussion_open_for_user",
+ return_value=True,
+ ), mock.patch(
+ "lms.djangoapps.discussion.rest_api.api._check_initializable_thread_fields",
+ side_effect=ValidationError("downstream validation"),
+ ), mock.patch(
+ "lms.djangoapps.discussion.rest_api.api.ENABLE_DISCUSSION_BAN.is_enabled",
+ return_value=True,
+ ), mock.patch(
+ "lms.djangoapps.discussion.rest_api.api.forum_api.is_user_banned",
+ return_value=True,
+ create=True,
+ ):
+ with pytest.raises(PermissionDenied, match="You are banned from posting"):
+ create_thread(request, {"course_id": "course-v1:x+y+z"})
+
+
+def test_create_comment_denies_banned_user():
+ request = RequestFactory().post('/dummy')
+ request.user = mock.Mock()
+ course = mock.Mock()
+ course.id = CourseKey.from_string("course-v1:x+y+z")
+
+ with mock.patch(
+ "lms.djangoapps.discussion.rest_api.api._get_thread_and_context",
+ return_value=({"closed": False}, {"course": course}),
+ ), mock.patch(
+ "lms.djangoapps.discussion.rest_api.api.discussion_open_for_user",
+ return_value=True,
+ ), mock.patch(
+ "lms.djangoapps.discussion.rest_api.api._check_initializable_comment_fields",
+ side_effect=ValidationError("downstream validation"),
+ ), mock.patch(
+ "lms.djangoapps.discussion.rest_api.api.ENABLE_DISCUSSION_BAN.is_enabled",
+ return_value=True,
+ ), mock.patch(
+ "lms.djangoapps.discussion.rest_api.api.forum_api.is_user_banned",
+ return_value=True,
+ create=True,
+ ):
+ with pytest.raises(PermissionDenied, match="You are banned from posting"):
+ create_comment(request, {"thread_id": "test_thread"})
+
+
+def test_create_thread_ban_check_backend_error_fails_open():
+ request = RequestFactory().post('/dummy')
+ request.user = mock.Mock(id=123)
+
+ with mock.patch(
+ "lms.djangoapps.discussion.rest_api.api._get_course",
+ return_value=mock.Mock(),
+ ), mock.patch(
+ "lms.djangoapps.discussion.rest_api.api.get_context",
+ return_value={},
+ ), mock.patch(
+ "lms.djangoapps.discussion.rest_api.api.discussion_open_for_user",
+ return_value=True,
+ ), mock.patch(
+ "lms.djangoapps.discussion.rest_api.api._check_initializable_thread_fields",
+ side_effect=ValidationError("downstream validation"),
+ ), mock.patch(
+ "lms.djangoapps.discussion.rest_api.api.ENABLE_DISCUSSION_BAN.is_enabled",
+ return_value=True,
+ ), mock.patch(
+ "lms.djangoapps.discussion.rest_api.api.forum_api.is_user_banned",
+ side_effect=CommentClientRequestError("temporary backend failure"),
+ create=True,
+ ), mock.patch("lms.djangoapps.discussion.rest_api.api.log.warning") as warning_log:
+ with pytest.raises(ValidationError):
+ create_thread(request, {"course_id": "course-v1:x+y+z"})
+
+ warning_log.assert_called_once()
+
+
+def test_create_comment_ban_check_backend_error_fails_open():
+ request = RequestFactory().post('/dummy')
+ request.user = mock.Mock(id=123)
+ course = mock.Mock()
+ course.id = CourseKey.from_string("course-v1:x+y+z")
+
+ with mock.patch(
+ "lms.djangoapps.discussion.rest_api.api._get_thread_and_context",
+ return_value=({"closed": False}, {"course": course}),
+ ), mock.patch(
+ "lms.djangoapps.discussion.rest_api.api.discussion_open_for_user",
+ return_value=True,
+ ), mock.patch(
+ "lms.djangoapps.discussion.rest_api.api._check_initializable_comment_fields",
+ side_effect=ValidationError("downstream validation"),
+ ), mock.patch(
+ "lms.djangoapps.discussion.rest_api.api.ENABLE_DISCUSSION_BAN.is_enabled",
+ return_value=True,
+ ), mock.patch(
+ "lms.djangoapps.discussion.rest_api.api.forum_api.is_user_banned",
+ side_effect=CommentClient500Error("temporary backend failure"),
+ create=True,
+ ), mock.patch("lms.djangoapps.discussion.rest_api.api.log.warning") as warning_log:
+ with pytest.raises(ValidationError):
+ create_comment(request, {"thread_id": "test_thread"})
+
+ warning_log.assert_called_once()
diff --git a/lms/djangoapps/discussion/rest_api/tests/test_api_v2.py b/lms/djangoapps/discussion/rest_api/tests/test_api_v2.py
index f5b15b905639..7f85c2c8c210 100644
--- a/lms/djangoapps/discussion/rest_api/tests/test_api_v2.py
+++ b/lms/djangoapps/discussion/rest_api/tests/test_api_v2.py
@@ -10,34 +10,20 @@
import random
from datetime import datetime, timedelta
from unittest import mock
-from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
import ddt
import httpretty
import pytest
-from django.test import override_settings
from django.contrib.auth import get_user_model
from django.core.exceptions import ValidationError
from django.test.client import RequestFactory
-from opaque_keys.edx.keys import CourseKey
from opaque_keys.edx.locator import CourseLocator
from pytz import UTC
from rest_framework.exceptions import PermissionDenied
-from xmodule.modulestore import ModuleStoreEnum
-from xmodule.modulestore.django import modulestore
-from xmodule.modulestore.tests.django_utils import (
- ModuleStoreTestCase,
- SharedModuleStoreTestCase,
-)
-from xmodule.modulestore.tests.factories import CourseFactory, BlockFactory
-from xmodule.partitions.partitions import Group, UserPartition
-
from common.djangoapps.student.tests.factories import (
AdminFactory,
- BetaTesterFactory,
CourseEnrollmentFactory,
- StaffFactory,
UserFactory,
)
from common.djangoapps.util.testing import UrlResetMixin
@@ -45,10 +31,6 @@
from lms.djangoapps.discussion.django_comment_client.tests.utils import (
ForumsEnableMixin,
)
-from lms.djangoapps.discussion.tests.utils import (
- make_minimal_cs_comment,
- make_minimal_cs_thread,
-)
from lms.djangoapps.discussion.rest_api import api
from lms.djangoapps.discussion.rest_api.api import (
create_comment,
@@ -56,12 +38,9 @@
delete_comment,
delete_thread,
get_comment_list,
- get_course,
- get_course_topics,
get_course_topics_v2,
get_thread,
get_thread_list,
- get_user_comments,
update_comment,
update_thread,
)
@@ -73,18 +52,19 @@
)
from lms.djangoapps.discussion.rest_api.serializers import TopicOrdering
from lms.djangoapps.discussion.rest_api.tests.utils import (
- CommentsServiceMockMixin,
ForumMockUtilsMixin,
make_paginated_api_response,
- parsed_body,
)
-from openedx.core.djangoapps.course_groups.models import CourseUserGroupPartitionGroup
+from lms.djangoapps.discussion.tests.utils import (
+ make_minimal_cs_comment,
+ make_minimal_cs_thread,
+)
from openedx.core.djangoapps.course_groups.tests.helpers import CohortFactory
from openedx.core.djangoapps.discussions.models import (
DiscussionsConfiguration,
DiscussionTopicLink,
- Provider,
PostingRestriction,
+ Provider,
)
from openedx.core.djangoapps.discussions.tasks import (
update_discussions_settings_from_course_task,
@@ -98,6 +78,13 @@
Role,
)
from openedx.core.lib.exceptions import CourseNotFoundError, PageNotFoundError
+from xmodule.modulestore import ModuleStoreEnum
+from xmodule.modulestore.django import modulestore
+from xmodule.modulestore.tests.django_utils import (
+ ModuleStoreTestCase,
+ SharedModuleStoreTestCase,
+)
+from xmodule.modulestore.tests.factories import BlockFactory, CourseFactory
User = get_user_model()
@@ -274,7 +261,11 @@ def test_basic(self, mock_emit):
)
self.register_post_thread_response(cs_thread)
with self.assert_signal_sent(
- api, "thread_created", sender=None, user=self.user, exclude_args=("post", "notify_all_learners")
+ api,
+ "thread_created",
+ sender=None,
+ user=self.user,
+ exclude_args=("post", "notify_all_learners"),
):
actual = create_thread(self.request, self.minimal_data)
expected = self.expected_thread_data(
@@ -353,7 +344,11 @@ def test_basic_in_blackout_period_with_user_access(self, mock_emit):
)
with self.assert_signal_sent(
- api, "thread_created", sender=None, user=self.user, exclude_args=("post", "notify_all_learners")
+ api,
+ "thread_created",
+ sender=None,
+ user=self.user,
+ exclude_args=("post", "notify_all_learners"),
):
actual = create_thread(self.request, self.minimal_data)
expected = self.expected_thread_data(
@@ -363,6 +358,7 @@ def test_basic_in_blackout_period_with_user_access(self, mock_emit):
"course_id": str(self.course.id),
"comment_list_url": "http://testserver/api/discussion/v1/comments/?thread_id=test_id",
"read": True,
+ "learner_status": "staff",
"editable_fields": [
"abuse_flagged",
"anonymous",
@@ -378,6 +374,7 @@ def test_basic_in_blackout_period_with_user_access(self, mock_emit):
"type",
"voted",
],
+ "is_deleted": False,
}
)
assert actual == expected
@@ -429,7 +426,11 @@ def test_title_truncation(self, mock_emit):
)
self.register_post_thread_response(cs_thread)
with self.assert_signal_sent(
- api, "thread_created", sender=None, user=self.user, exclude_args=("post", "notify_all_learners")
+ api,
+ "thread_created",
+ sender=None,
+ user=self.user,
+ exclude_args=("post", "notify_all_learners"),
):
create_thread(self.request, data)
event_name, event_data = mock_emit.call_args[0]
@@ -689,6 +690,9 @@ def test_success(self, parent_id, mock_emit):
"parent_id": parent_id,
"author": self.user.username,
"author_label": None,
+ "is_author_banned": False,
+ "author_ban_scope": None,
+ "learner_status": "new",
"created_at": "2015-05-27T00:00:00Z",
"updated_at": "2015-05-27T00:00:00Z",
"raw_body": "Test body",
@@ -716,6 +720,10 @@ def test_success(self, parent_id, mock_emit):
"image_url_medium": "http://testserver/static/default_50.png",
"image_url_small": "http://testserver/static/default_30.png",
},
+ "is_deleted": False,
+ "deleted_at": None,
+ "deleted_by": None,
+ "deleted_by_label": None,
}
assert actual == expected
@@ -796,6 +804,9 @@ def test_success_in_black_out_with_user_access(self, parent_id, mock_emit):
"parent_id": parent_id,
"author": self.user.username,
"author_label": "Moderator",
+ "is_author_banned": False,
+ "author_ban_scope": None,
+ "learner_status": "staff",
"created_at": "2015-05-27T00:00:00Z",
"updated_at": "2015-05-27T00:00:00Z",
"raw_body": "Test body",
@@ -823,6 +834,10 @@ def test_success_in_black_out_with_user_access(self, parent_id, mock_emit):
"image_url_medium": "http://testserver/static/default_50.png",
"image_url_small": "http://testserver/static/default_30.png",
},
+ "is_deleted": False,
+ "deleted_at": None,
+ "deleted_by": None,
+ "deleted_by_label": None,
}
assert actual == expected
@@ -911,7 +926,9 @@ def test_endorsed(self, role_name, is_thread_author, thread_type):
)
try:
create_comment(self.request, data)
- last_commemt_params = self.get_mock_func_calls("create_parent_comment")[-1][1]
+ last_commemt_params = self.get_mock_func_calls("create_parent_comment")[-1][
+ 1
+ ]
assert last_commemt_params["endorsed"]
assert not expected_error
except ValidationError:
@@ -1799,6 +1816,9 @@ def test_basic(self, parent_id):
"parent_id": parent_id,
"author": self.user.username,
"author_label": None,
+ "is_author_banned": False,
+ "author_ban_scope": None,
+ "learner_status": "new",
"created_at": "2015-06-03T00:00:00Z",
"updated_at": "2015-06-03T00:00:00Z",
"raw_body": "Edited body",
@@ -1824,6 +1844,10 @@ def test_basic(self, parent_id):
"image_url_medium": "http://testserver/static/default_50.png",
"image_url_small": "http://testserver/static/default_30.png",
},
+ "is_deleted": False,
+ "deleted_at": None,
+ "deleted_by": None,
+ "deleted_by_label": None,
}
assert actual == expected
params = {
@@ -1884,7 +1908,7 @@ def test_abuse_flagged(self, old_flagged, new_flagged, mock_emit):
else "edx.forum.response.unreported"
)
expected_event_data = {
- "discussion": {'id': 'test_thread'},
+ "discussion": {"id": "test_thread"},
"body": "Original body",
"id": "test_comment",
"content_type": "Response",
@@ -1947,7 +1971,7 @@ def test_comment_un_abuse_flag_for_moderator_role(
"body": "Original body",
"id": "test_comment",
"content_type": "Response",
- "discussion": {'id': 'test_thread'},
+ "discussion": {"id": "test_thread"},
"commentable_id": "dummy",
"truncated": False,
"url": "",
@@ -2366,6 +2390,7 @@ def test_basic(self, mock_emit):
params = {
"thread_id": self.thread_id,
"course_id": str(self.course.id),
+ "deleted_by": str(self.user.id),
}
self.check_mock_called_with("delete_thread", -1, **params)
@@ -2553,6 +2578,7 @@ def test_basic(self, mock_emit):
params = {
"comment_id": self.comment_id,
"course_id": str(self.course.id),
+ "deleted_by": str(self.user.id),
}
self.check_mock_called_with("delete_comment", -1, **params)
@@ -2722,6 +2748,7 @@ def register_thread(self, overrides=None):
"title": "Test Title",
"body": "Test body",
"resp_total": 0,
+ "is_deleted": False,
}
)
cs_data.update(overrides or {})
@@ -2756,6 +2783,7 @@ def test_nonauthor_enrolled_in_course(self):
"voted",
],
"unread_comment_count": 1,
+ "is_deleted": None,
}
)
self.check_mock_called("get_thread")
@@ -2917,6 +2945,7 @@ def test_get_threads_by_topic_id(self):
"page": 1,
"per_page": 1,
"commentable_ids": ["topic_x", "topic_meow"],
+ "show_deleted": False,
}
self.check_mock_called_with(
"get_user_threads",
@@ -2932,6 +2961,7 @@ def test_basic_query_params(self):
"sort_key": "activity",
"page": 6,
"per_page": 14,
+ "show_deleted": False,
}
self.check_mock_called_with(
"get_user_threads",
@@ -2959,6 +2989,7 @@ def test_thread_content(self):
"read": True,
"created_at": "2015-04-28T00:00:00Z",
"updated_at": "2015-04-28T11:11:11Z",
+ "is_deleted": False,
}
),
make_minimal_cs_thread(
@@ -2976,6 +3007,7 @@ def test_thread_content(self):
"comments_count": 18,
"created_at": "2015-04-28T22:22:22Z",
"updated_at": "2015-04-28T00:33:33Z",
+ "is_deleted": False,
}
),
]
@@ -3002,6 +3034,7 @@ def test_thread_content(self):
"updated_at": "2015-04-28T11:11:11Z",
"abuse_flagged_count": None,
"can_delete": False,
+ "is_deleted": None,
}
),
self.expected_thread_data(
@@ -3036,6 +3069,7 @@ def test_thread_content(self):
],
"abuse_flagged_count": None,
"can_delete": False,
+ "is_deleted": None,
}
),
]
@@ -3072,10 +3106,10 @@ def test_request_group(self, role_name, course_is_cohorted):
self.get_thread_list([], course=cohort_course)
thread_func_params = self.get_mock_func_calls("get_user_threads")[-1][1]
actual_has_group = "group_id" in thread_func_params
- expected_has_group = (
- course_is_cohorted and role_name in (
- FORUM_ROLE_STUDENT, FORUM_ROLE_COMMUNITY_TA, FORUM_ROLE_GROUP_MODERATOR
- )
+ expected_has_group = course_is_cohorted and role_name in (
+ FORUM_ROLE_STUDENT,
+ FORUM_ROLE_COMMUNITY_TA,
+ FORUM_ROLE_GROUP_MODERATOR,
)
assert actual_has_group == expected_has_group
@@ -3140,6 +3174,7 @@ def test_text_search(self, text_search_rewrite):
"page": 1,
"per_page": 10,
"text": "test search string",
+ "show_deleted": False,
}
self.check_mock_called_with(
"search_threads",
@@ -3166,6 +3201,7 @@ def test_filter_threads_by_author(self):
"page": 1,
"per_page": 10,
"author_id": str(self.user.id),
+ "show_deleted": False,
}
self.check_mock_called_with(
"get_user_threads",
@@ -3212,6 +3248,7 @@ def test_thread_type(self, thread_type):
"page": 1,
"per_page": 10,
"thread_type": thread_type,
+ "show_deleted": False,
}
if thread_type is None:
@@ -3249,6 +3286,7 @@ def test_flagged(self, flagged_boolean):
"page": 1,
"per_page": 10,
"flagged": flagged_boolean,
+ "show_deleted": False,
}
if flagged_boolean is None:
@@ -3289,6 +3327,7 @@ def test_flagged_count(self, role):
"count_flagged": True,
"page": 1,
"per_page": 10,
+ "show_deleted": False,
}
self.check_mock_called_with(
@@ -3337,6 +3376,7 @@ def test_following(self):
"sort_key": "activity",
"page": 1,
"per_page": 11,
+ "show_deleted": False,
}
self.check_mock_called_with("get_user_subscriptions", -1, **params)
@@ -3364,6 +3404,7 @@ def test_view_query(self, query):
"page": 1,
"per_page": 11,
query: True,
+ "show_deleted": False,
}
self.check_mock_called_with(
"get_user_threads",
@@ -3405,6 +3446,7 @@ def test_order_by_query(self, http_query, cc_query):
"sort_key": cc_query,
"page": 1,
"per_page": 11,
+ "show_deleted": False,
}
self.check_mock_called_with(
"get_user_threads",
@@ -3437,6 +3479,7 @@ def test_order_direction(self):
"sort_key": "activity",
"page": 1,
"per_page": 11,
+ "show_deleted": False,
}
self.check_mock_called_with(
"get_user_threads",
@@ -3711,6 +3754,7 @@ def get_source_and_expected_comments(self):
"votes": {"up_count": 4},
"child_count": 0,
"children": [],
+ "is_deleted": False,
},
{
"type": "comment",
@@ -3728,6 +3772,7 @@ def get_source_and_expected_comments(self):
"votes": {"up_count": 7},
"child_count": 0,
"children": [],
+ "is_deleted": False,
},
]
expected_comments = [
@@ -3737,6 +3782,9 @@ def get_source_and_expected_comments(self):
"parent_id": None,
"author": self.author.username,
"author_label": None,
+ "is_author_banned": False,
+ "author_ban_scope": None,
+ "learner_status": "new",
"created_at": "2015-05-11T00:00:00Z",
"updated_at": "2015-05-11T11:11:11Z",
"raw_body": "Test body",
@@ -3764,6 +3812,10 @@ def get_source_and_expected_comments(self):
"image_url_medium": "http://testserver/static/default_50.png",
"image_url_small": "http://testserver/static/default_30.png",
},
+ "is_deleted": None,
+ "deleted_at": None,
+ "deleted_by": None,
+ "deleted_by_label": None,
},
{
"id": "test_comment_2",
@@ -3771,6 +3823,9 @@ def get_source_and_expected_comments(self):
"parent_id": None,
"author": None,
"author_label": None,
+ "is_author_banned": False,
+ "author_ban_scope": None,
+ "learner_status": "anonymous",
"created_at": "2015-05-11T22:22:22Z",
"updated_at": "2015-05-11T33:33:33Z",
"raw_body": "More content",
@@ -3798,6 +3853,10 @@ def get_source_and_expected_comments(self):
"image_url_medium": "http://testserver/static/default_50.png",
"image_url_small": "http://testserver/static/default_30.png",
},
+ "is_deleted": None,
+ "deleted_at": None,
+ "deleted_by": None,
+ "deleted_by_label": None,
},
]
return source_comments, expected_comments
diff --git a/lms/djangoapps/discussion/rest_api/tests/test_forms.py b/lms/djangoapps/discussion/rest_api/tests/test_forms.py
index 3be65964b6b9..33359337933b 100644
--- a/lms/djangoapps/discussion/rest_api/tests/test_forms.py
+++ b/lms/djangoapps/discussion/rest_api/tests/test_forms.py
@@ -2,7 +2,6 @@
Tests for Discussion API forms
"""
-
import itertools
from unittest import TestCase
from urllib.parse import urlencode
@@ -12,9 +11,9 @@
from opaque_keys.edx.locator import CourseLocator
from lms.djangoapps.discussion.rest_api.forms import (
- UserCommentListGetForm,
CommentListGetForm,
ThreadListGetForm,
+ UserCommentListGetForm,
)
from openedx.core.djangoapps.util.test_forms import FormTestMixin
@@ -36,7 +35,9 @@ def test_missing_page_size(self):
def test_zero_page_size(self):
self.form_data["page_size"] = "0"
- self.assert_error("page_size", "Ensure this value is greater than or equal to 1.")
+ self.assert_error(
+ "page_size", "Ensure this value is greater than or equal to 1."
+ )
def test_excessive_page_size(self):
self.form_data["page_size"] = "101"
@@ -46,6 +47,7 @@ def test_excessive_page_size(self):
@ddt.ddt
class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
"""Tests for ThreadListGetForm"""
+
FORM_CLASS = ThreadListGetForm
def setUp(self):
@@ -58,37 +60,41 @@ def setUp(self):
"page_size": "13",
}
),
- mutable=True
+ mutable=True,
)
def test_basic(self):
form = self.get_form(expected_valid=True)
assert form.cleaned_data == {
- 'course_id': CourseLocator.from_string('Foo/Bar/Baz'),
- 'page': 2,
- 'page_size': 13,
- 'count_flagged': None,
- 'topic_id': set(),
- 'text_search': '',
- 'following': None,
- 'author': '',
- 'thread_type': '',
- 'flagged': None,
- 'view': '',
- 'order_by': 'last_activity_at',
- 'order_direction': 'desc',
- 'requested_fields': set()
+ "course_id": CourseLocator.from_string("Foo/Bar/Baz"),
+ "page": 2,
+ "page_size": 13,
+ "count_flagged": None,
+ "topic_id": set(),
+ "text_search": "",
+ "following": None,
+ "author": "",
+ "thread_type": "",
+ "flagged": None,
+ "show_deleted": None,
+ "view": "",
+ "order_by": "last_activity_at",
+ "order_direction": "desc",
+ "requested_fields": set(),
}
def test_topic_id(self):
self.form_data.setlist("topic_id", ["example topic_id", "example 2nd topic_id"])
form = self.get_form(expected_valid=True)
- assert form.cleaned_data['topic_id'] == {'example topic_id', 'example 2nd topic_id'}
+ assert form.cleaned_data["topic_id"] == {
+ "example topic_id",
+ "example 2nd topic_id",
+ }
def test_text_search(self):
self.form_data["text_search"] = "test search string"
form = self.get_form(expected_valid=True)
- assert form.cleaned_data['text_search'] == 'test search string'
+ assert form.cleaned_data["text_search"] == "test search string"
def test_missing_course_id(self):
self.form_data.pop("course_id")
@@ -109,7 +115,10 @@ def test_thread_type(self, value):
def test_thread_type_invalid(self):
self.form_data["thread_type"] = "invalid-option"
- self.assert_error("thread_type", "Select a valid choice. invalid-option is not one of the available choices.")
+ self.assert_error(
+ "thread_type",
+ "Select a valid choice. invalid-option is not one of the available choices.",
+ )
@ddt.data("True", "true", 1, True)
def test_flagged_true(self, value):
@@ -133,7 +142,9 @@ def test_following_true(self, value):
@ddt.data("False", "false", 0, False)
def test_following_false(self, value):
self.form_data["following"] = value
- self.assert_error("following", "The value of the 'following' parameter must be true.")
+ self.assert_error(
+ "following", "The value of the 'following' parameter must be true."
+ )
def test_invalid_following(self):
self.form_data["following"] = "invalid-boolean"
@@ -144,25 +155,28 @@ def test_mutually_exclusive(self, params):
self.form_data.update({param: "True" for param in params})
self.assert_error(
"__all__",
- "The following query parameters are mutually exclusive: topic_id, text_search, following"
+ "The following query parameters are mutually exclusive: topic_id, text_search, following",
)
def test_invalid_view_choice(self):
self.form_data["view"] = "not_a_valid_choice"
- self.assert_error("view", "Select a valid choice. not_a_valid_choice is not one of the available choices.")
+ self.assert_error(
+ "view",
+ "Select a valid choice. not_a_valid_choice is not one of the available choices.",
+ )
def test_invalid_sort_by_choice(self):
self.form_data["order_by"] = "not_a_valid_choice"
self.assert_error(
"order_by",
- "Select a valid choice. not_a_valid_choice is not one of the available choices."
+ "Select a valid choice. not_a_valid_choice is not one of the available choices.",
)
def test_invalid_sort_direction_choice(self):
self.form_data["order_direction"] = "not_a_valid_choice"
self.assert_error(
"order_direction",
- "Select a valid choice. not_a_valid_choice is not one of the available choices."
+ "Select a valid choice. not_a_valid_choice is not one of the available choices.",
)
@ddt.data(
@@ -181,12 +195,13 @@ def test_valid_choice_fields(self, field, value):
def test_requested_fields(self):
self.form_data["requested_fields"] = "profile_image"
form = self.get_form(expected_valid=True)
- assert form.cleaned_data['requested_fields'] == {'profile_image'}
+ assert form.cleaned_data["requested_fields"] == {"profile_image"}
@ddt.ddt
class CommentListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
"""Tests for CommentListGetForm"""
+
FORM_CLASS = CommentListGetForm
def setUp(self):
@@ -202,13 +217,14 @@ def setUp(self):
def test_basic(self):
form = self.get_form(expected_valid=True)
assert form.cleaned_data == {
- 'thread_id': 'deadbeef',
- 'endorsed': False,
- 'page': 2,
- 'page_size': 13,
- 'flagged': False,
- 'requested_fields': set(),
- 'merge_question_type_responses': False
+ "thread_id": "deadbeef",
+ "endorsed": False,
+ "page": 2,
+ "page_size": 13,
+ "flagged": False,
+ "requested_fields": set(),
+ "merge_question_type_responses": False,
+ "show_deleted": None,
}
def test_missing_thread_id(self):
@@ -236,12 +252,13 @@ def test_invalid_endorsed(self):
def test_requested_fields(self):
self.form_data["requested_fields"] = {"profile_image"}
form = self.get_form(expected_valid=True)
- assert form.cleaned_data['requested_fields'] == {'profile_image'}
+ assert form.cleaned_data["requested_fields"] == {"profile_image"}
@ddt.ddt
class UserCommentListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
"""Tests for UserCommentListGetForm"""
+
FORM_CLASS = UserCommentListGetForm
def setUp(self):
@@ -256,11 +273,11 @@ def setUp(self):
def test_basic(self):
form = self.get_form(expected_valid=True)
assert form.cleaned_data == {
- 'course_id': CourseLocator.from_string('a/b/c'),
- 'flagged': False,
- 'page': 2,
- 'page_size': 13,
- 'requested_fields': set()
+ "course_id": CourseLocator.from_string("a/b/c"),
+ "flagged": False,
+ "page": 2,
+ "page_size": 13,
+ "requested_fields": set(),
}
def test_missing_flagged(self):
@@ -280,7 +297,7 @@ def test_flagged_true(self, value):
def test_requested_fields(self):
self.form_data["requested_fields"] = {"profile_image"}
form = self.get_form(expected_valid=True)
- assert form.cleaned_data['requested_fields'] == {'profile_image'}
+ assert form.cleaned_data["requested_fields"] == {"profile_image"}
def test_missing_course_id(self):
self.form_data.pop("course_id")
diff --git a/lms/djangoapps/discussion/rest_api/tests/test_moderation_emails.py b/lms/djangoapps/discussion/rest_api/tests/test_moderation_emails.py
new file mode 100644
index 000000000000..9256596945e2
--- /dev/null
+++ b/lms/djangoapps/discussion/rest_api/tests/test_moderation_emails.py
@@ -0,0 +1,238 @@
+"""
+Tests for discussion moderation email notifications.
+"""
+from unittest import mock
+from django.test import override_settings
+from django.core import mail
+
+from lms.djangoapps.discussion.rest_api.emails import send_ban_escalation_email
+from common.djangoapps.student.tests.factories import UserFactory
+from xmodule.modulestore.tests.factories import CourseFactory
+from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
+
+
+class BanEscalationEmailTest(ModuleStoreTestCase):
+ """Tests for send_ban_escalation_email function."""
+
+ def setUp(self):
+ super().setUp()
+ self.course = CourseFactory.create(org='TestX', number='CS101', run='2024')
+ self.course_key = str(self.course.id)
+ self.banned_user = UserFactory.create(username='spammer', email='spammer@example.com')
+ self.moderator = UserFactory.create(username='moderator', email='mod@example.com')
+
+ @override_settings(DISCUSSION_MODERATION_BAN_EMAIL_ENABLED=False)
+ def test_email_disabled_by_setting(self):
+ """Test that email is not sent when DISCUSSION_MODERATION_BAN_EMAIL_ENABLED is False."""
+ # Clear outbox
+ mail.outbox = []
+
+ # Try to send email
+ send_ban_escalation_email(
+ banned_user_id=self.banned_user.id,
+ moderator_id=self.moderator.id,
+ course_id=self.course_key,
+ scope='course',
+ reason='Spam',
+ threads_deleted=5,
+ comments_deleted=10
+ )
+
+ # No email should be sent
+ self.assertEqual(len(mail.outbox), 0)
+
+ @override_settings(
+ DISCUSSION_MODERATION_BAN_EMAIL_ENABLED=True,
+ DISCUSSION_MODERATION_ESCALATION_EMAIL='partner-support@edx.org'
+ )
+ @mock.patch('lms.djangoapps.discussion.rest_api.emails.ace')
+ def test_email_sent_via_ace(self, mock_ace_module):
+ """Test that email is sent via ACE when available."""
+ # Create mock ACE send function
+ mock_send = mock.MagicMock()
+ mock_ace_module.send = mock_send
+
+ send_ban_escalation_email(
+ banned_user_id=self.banned_user.id,
+ moderator_id=self.moderator.id,
+ course_id=self.course_key,
+ scope='course',
+ reason='Posting scam links',
+ threads_deleted=3,
+ comments_deleted=7
+ )
+
+ # ACE send should be called
+ mock_send.assert_called_once()
+
+ # Get the message argument
+ call_args = mock_send.call_args
+ message = call_args[0][0]
+
+ # Verify message properties
+ self.assertEqual(message.recipient.email_address, 'partner-support@edx.org')
+ self.assertEqual(message.context['banned_username'], 'spammer')
+ self.assertEqual(message.context['moderator_username'], 'moderator')
+ self.assertEqual(message.context['scope'], 'course')
+ self.assertEqual(message.context['reason'], 'Posting scam links')
+ self.assertEqual(message.context['threads_deleted'], 3)
+ self.assertEqual(message.context['comments_deleted'], 7)
+ self.assertEqual(message.context['total_deleted'], 10)
+
+ @override_settings(
+ DISCUSSION_MODERATION_BAN_EMAIL_ENABLED=True,
+ DISCUSSION_MODERATION_ESCALATION_EMAIL='custom-support@example.com',
+ DEFAULT_FROM_EMAIL='noreply@edx.org'
+ )
+ @mock.patch('lms.djangoapps.discussion.rest_api.emails.ace', None)
+ def test_email_fallback_to_django_mail(self):
+ """Test that email falls back to Django mail when ACE is not available."""
+ # Clear outbox
+ mail.outbox = []
+
+ # Simulate ACE not being importable by making the import fail
+ import sys
+ original_modules = sys.modules.copy()
+
+ # Remove ace modules if present
+ ace_modules = [key for key in sys.modules if key.startswith('edx_ace')]
+ for mod in ace_modules:
+ sys.modules.pop(mod, None)
+
+ try:
+ send_ban_escalation_email(
+ banned_user_id=self.banned_user.id,
+ moderator_id=self.moderator.id,
+ course_id=self.course_key,
+ scope='organization',
+ reason='Multiple violations',
+ threads_deleted=15,
+ comments_deleted=25
+ )
+ finally:
+ # Restore modules
+ sys.modules.update(original_modules)
+
+ # Email should be sent via Django
+ self.assertEqual(len(mail.outbox), 1)
+
+ email = mail.outbox[0]
+ self.assertIn('custom-support@example.com', email.to)
+ self.assertEqual(email.from_email, 'noreply@edx.org')
+ self.assertIn('spammer', email.body)
+ self.assertIn('moderator', email.body)
+ self.assertIn('Multiple violations', email.body)
+ self.assertIn('ORGANIZATION', email.body)
+ self.assertIn('15', email.body) # threads_deleted
+ self.assertIn('25', email.body) # comments_deleted
+
+ @override_settings(
+ DISCUSSION_MODERATION_BAN_EMAIL_ENABLED=True,
+ DISCUSSION_MODERATION_ESCALATION_EMAIL='support@example.com'
+ )
+ @mock.patch('lms.djangoapps.discussion.rest_api.emails.ace', None)
+ def test_email_handles_missing_reason(self):
+ """Test that email handles empty/None reason gracefully."""
+ mail.outbox = []
+
+ # Send with empty reason (will use Django mail since ace is None)
+ send_ban_escalation_email(
+ banned_user_id=self.banned_user.id,
+ moderator_id=self.moderator.id,
+ course_id=self.course_key,
+ scope='course',
+ reason='',
+ threads_deleted=1,
+ comments_deleted=0
+ )
+
+ self.assertEqual(len(mail.outbox), 1)
+ email = mail.outbox[0]
+ # Should use default text
+ self.assertIn('No reason provided', email.body)
+
+ @override_settings(
+ DISCUSSION_MODERATION_BAN_EMAIL_ENABLED=True,
+ DISCUSSION_MODERATION_ESCALATION_EMAIL='support@example.com'
+ )
+ @mock.patch('lms.djangoapps.discussion.rest_api.emails.ace', None)
+ def test_email_with_org_level_ban(self):
+ """Test email for organization-level ban."""
+ mail.outbox = []
+
+ send_ban_escalation_email(
+ banned_user_id=self.banned_user.id,
+ moderator_id=self.moderator.id,
+ course_id=self.course_key,
+ scope='organization',
+ reason='Org-wide spam campaign',
+ threads_deleted=50,
+ comments_deleted=100
+ )
+
+ self.assertEqual(len(mail.outbox), 1)
+ email = mail.outbox[0]
+ self.assertIn('ORGANIZATION', email.body)
+ self.assertIn('Org-wide spam campaign', email.body)
+
+ @override_settings(
+ DISCUSSION_MODERATION_BAN_EMAIL_ENABLED=True,
+ DISCUSSION_MODERATION_ESCALATION_EMAIL='support@example.com'
+ )
+ @mock.patch('lms.djangoapps.discussion.rest_api.emails.ace', None)
+ def test_email_failure_logged(self):
+ """Test that email failures are properly logged."""
+ with mock.patch('django.core.mail.send_mail', side_effect=Exception("SMTP error")):
+ with self.assertLogs('lms.djangoapps.discussion.rest_api.emails', level='ERROR') as logs:
+ with self.assertRaises(Exception):
+ send_ban_escalation_email(
+ banned_user_id=self.banned_user.id,
+ moderator_id=self.moderator.id,
+ course_id=self.course_key,
+ scope='course',
+ reason='Test',
+ threads_deleted=1,
+ comments_deleted=1
+ )
+
+ # Verify error was logged
+ self.assertTrue(any('Failed to send ban escalation email' in log for log in logs.output))
+
+ @override_settings(DISCUSSION_MODERATION_BAN_EMAIL_ENABLED=True)
+ def test_email_with_invalid_user_id(self):
+ """Test that email handles invalid user IDs gracefully."""
+ with self.assertRaises(Exception):
+ send_ban_escalation_email(
+ banned_user_id=99999, # Non-existent user
+ moderator_id=self.moderator.id,
+ course_id=self.course_key,
+ scope='course',
+ reason='Test',
+ threads_deleted=0,
+ comments_deleted=0
+ )
+
+ @override_settings(
+ DISCUSSION_MODERATION_BAN_EMAIL_ENABLED=True,
+ DISCUSSION_MODERATION_ESCALATION_EMAIL='test@example.com'
+ )
+ @mock.patch('lms.djangoapps.discussion.rest_api.emails.ace', None)
+ def test_email_subject_format(self):
+ """Test that email subject is properly formatted."""
+ mail.outbox = []
+
+ send_ban_escalation_email(
+ banned_user_id=self.banned_user.id,
+ moderator_id=self.moderator.id,
+ course_id=self.course_key,
+ scope='course',
+ reason='Test ban',
+ threads_deleted=1,
+ comments_deleted=1
+ )
+
+ self.assertEqual(len(mail.outbox), 1)
+ email = mail.outbox[0]
+ # Subject should contain username and course
+ self.assertIn('spammer', email.subject)
+ self.assertIn(self.course_key, email.subject)
diff --git a/lms/djangoapps/discussion/rest_api/tests/test_moderation_permissions.py b/lms/djangoapps/discussion/rest_api/tests/test_moderation_permissions.py
new file mode 100644
index 000000000000..73981fd5b2cb
--- /dev/null
+++ b/lms/djangoapps/discussion/rest_api/tests/test_moderation_permissions.py
@@ -0,0 +1,203 @@
+"""
+Tests for discussion moderation permissions.
+"""
+from unittest.mock import Mock
+
+from rest_framework.test import APIRequestFactory
+
+from common.djangoapps.student.roles import CourseStaffRole, CourseInstructorRole, GlobalStaff
+from common.djangoapps.student.tests.factories import UserFactory
+from lms.djangoapps.discussion.rest_api.permissions import (
+ IsAllowedToBulkDelete,
+ can_take_action_on_spam,
+)
+from openedx.core.djangoapps.django_comment_common.models import Role
+from xmodule.modulestore.tests.factories import CourseFactory
+from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
+
+
+class CanTakeActionOnSpamTest(ModuleStoreTestCase):
+ """Tests for can_take_action_on_spam permission helper function."""
+
+ def setUp(self):
+ super().setUp()
+ self.course = CourseFactory.create(org='TestX', number='CS101', run='2024')
+ self.course_key = self.course.id
+
+ def test_global_staff_has_permission(self):
+ """Global staff should have permission."""
+ user = UserFactory.create(is_staff=True)
+ self.assertTrue(can_take_action_on_spam(user, self.course_key))
+
+ def test_global_staff_role_has_permission(self):
+ """Users with GlobalStaff role should have permission."""
+ user = UserFactory.create()
+ GlobalStaff().add_users(user)
+ self.assertTrue(can_take_action_on_spam(user, self.course_key))
+
+ def test_course_staff_has_permission(self):
+ """Course staff should have permission for their course."""
+ user = UserFactory.create()
+ CourseStaffRole(self.course_key).add_users(user)
+ self.assertTrue(can_take_action_on_spam(user, self.course_key))
+
+ def test_course_instructor_has_permission(self):
+ """Course instructors should have permission for their course."""
+ user = UserFactory.create()
+ CourseInstructorRole(self.course_key).add_users(user)
+ self.assertTrue(can_take_action_on_spam(user, self.course_key))
+
+ def test_forum_moderator_has_permission(self):
+ """Forum moderators should have permission for their course."""
+ user = UserFactory.create()
+ role = Role.objects.create(name='Moderator', course_id=self.course_key)
+ role.users.add(user)
+ self.assertTrue(can_take_action_on_spam(user, self.course_key))
+
+ def test_forum_administrator_has_permission(self):
+ """Forum administrators should have permission for their course."""
+ user = UserFactory.create()
+ role = Role.objects.create(name='Administrator', course_id=self.course_key)
+ role.users.add(user)
+ self.assertTrue(can_take_action_on_spam(user, self.course_key))
+
+ def test_regular_student_no_permission(self):
+ """Regular students should not have permission."""
+ user = UserFactory.create()
+ self.assertFalse(can_take_action_on_spam(user, self.course_key))
+
+ def test_community_ta_no_permission(self):
+ """Community TAs should not have bulk delete permission."""
+ user = UserFactory.create()
+ role = Role.objects.create(name='Community TA', course_id=self.course_key)
+ role.users.add(user)
+ self.assertFalse(can_take_action_on_spam(user, self.course_key))
+
+ def test_staff_different_course_no_permission(self):
+ """Staff from a different course should not have permission."""
+ other_course = CourseFactory.create(org='OtherX', number='CS201', run='2024')
+ user = UserFactory.create()
+ CourseStaffRole(other_course.id).add_users(user)
+ self.assertFalse(can_take_action_on_spam(user, self.course_key))
+
+ def test_accepts_string_course_id(self):
+ """Function should accept string course_id and convert it."""
+ user = UserFactory.create()
+ CourseStaffRole(self.course_key).add_users(user)
+ self.assertTrue(can_take_action_on_spam(user, str(self.course_key)))
+
+
+class IsAllowedToBulkDeleteTest(ModuleStoreTestCase):
+ """Tests for IsAllowedToBulkDelete permission class."""
+
+ def setUp(self):
+ super().setUp()
+ self.course = CourseFactory.create(org='TestX', number='CS101', run='2024')
+ self.course_key = str(self.course.id)
+ self.factory = APIRequestFactory()
+ self.permission = IsAllowedToBulkDelete()
+
+ def _create_view_with_kwargs(self, course_id=None):
+ """Helper to create a mock view with kwargs."""
+ view = Mock()
+ view.kwargs = {'course_id': course_id} if course_id else {}
+ return view
+
+ def _create_request_with_data(self, user, course_id=None, method='POST'):
+ """Helper to create a request with data."""
+ if method == 'POST':
+ request = self.factory.post('/api/discussion/v1/moderation/bulk-delete-ban/')
+ else:
+ request = self.factory.get('/api/discussion/v1/moderation/banned-users/')
+
+ request.user = user
+ request.data = {'course_id': course_id} if course_id else {}
+ return request
+
+ def test_unauthenticated_user_denied(self):
+ """Unauthenticated users should be denied."""
+ request = self.factory.post('/api/discussion/v1/moderation/bulk-delete-ban/')
+ request.user = Mock(is_authenticated=False)
+ view = self._create_view_with_kwargs()
+
+ self.assertFalse(self.permission.has_permission(request, view))
+
+ def test_global_staff_with_course_id_in_data(self):
+ """Global staff should have permission when course_id is in request data."""
+ user = UserFactory.create(is_staff=True)
+ request = self._create_request_with_data(user, self.course_key)
+ view = self._create_view_with_kwargs()
+
+ self.assertTrue(self.permission.has_permission(request, view))
+
+ def test_course_staff_with_course_id_in_data(self):
+ """Course staff should have permission when course_id is in request data."""
+ user = UserFactory.create()
+ CourseStaffRole(self.course.id).add_users(user)
+ request = self._create_request_with_data(user, self.course_key)
+ view = self._create_view_with_kwargs()
+
+ self.assertTrue(self.permission.has_permission(request, view))
+
+ def test_course_instructor_with_course_id_in_data(self):
+ """Course instructors should have permission when course_id is in request data."""
+ user = UserFactory.create()
+ CourseInstructorRole(self.course.id).add_users(user)
+ request = self._create_request_with_data(user, self.course_key)
+ view = self._create_view_with_kwargs()
+
+ self.assertTrue(self.permission.has_permission(request, view))
+
+ def test_forum_moderator_with_course_id_in_data(self):
+ """Forum moderators should have permission when course_id is in request data."""
+ user = UserFactory.create()
+ role = Role.objects.create(name='Moderator', course_id=self.course.id)
+ role.users.add(user)
+ request = self._create_request_with_data(user, self.course_key)
+ view = self._create_view_with_kwargs()
+
+ self.assertTrue(self.permission.has_permission(request, view))
+
+ def test_regular_student_denied(self):
+ """Regular students should be denied."""
+ user = UserFactory.create()
+ request = self._create_request_with_data(user, self.course_key)
+ view = self._create_view_with_kwargs()
+
+ self.assertFalse(self.permission.has_permission(request, view))
+
+ def test_course_id_in_url_kwargs(self):
+ """Permission should work when course_id is in URL kwargs."""
+ user = UserFactory.create()
+ CourseStaffRole(self.course.id).add_users(user)
+ request = self.factory.get('/api/discussion/v1/moderation/banned-users/')
+ request.user = user
+ request.data = {}
+ request.query_params = {}
+ view = self._create_view_with_kwargs(self.course_key)
+
+ self.assertTrue(self.permission.has_permission(request, view))
+
+ def test_no_course_id_only_global_staff_allowed(self):
+ """When no course_id provided, only global staff should be allowed."""
+ # Global staff allowed
+ global_staff = UserFactory.create(is_staff=True)
+ request = self._create_request_with_data(global_staff)
+ view = self._create_view_with_kwargs()
+ self.assertTrue(self.permission.has_permission(request, view))
+
+ # Regular user denied
+ regular_user = UserFactory.create()
+ request = self._create_request_with_data(regular_user)
+ view = self._create_view_with_kwargs()
+ self.assertFalse(self.permission.has_permission(request, view))
+
+ def test_staff_different_course_denied(self):
+ """Staff from different course should be denied."""
+ other_course = CourseFactory.create(org='OtherX', number='CS201', run='2024')
+ user = UserFactory.create()
+ CourseStaffRole(other_course.id).add_users(user)
+ request = self._create_request_with_data(user, self.course_key)
+ view = self._create_view_with_kwargs()
+
+ self.assertFalse(self.permission.has_permission(request, view))
diff --git a/lms/djangoapps/discussion/rest_api/tests/test_permissions.py b/lms/djangoapps/discussion/rest_api/tests/test_permissions.py
index 405726e2125b..e0a325a3fa3d 100644
--- a/lms/djangoapps/discussion/rest_api/tests/test_permissions.py
+++ b/lms/djangoapps/discussion/rest_api/tests/test_permissions.py
@@ -4,12 +4,16 @@
import itertools
+from unittest.mock import Mock
import ddt
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
from xmodule.modulestore.tests.factories import CourseFactory
+from common.djangoapps.student.roles import CourseInstructorRole, CourseStaffRole
+from common.djangoapps.student.tests.factories import UserFactory
from lms.djangoapps.discussion.rest_api.permissions import (
+ IsAllowedToRestore,
can_delete,
get_editable_fields,
get_initializable_comment_fields,
@@ -18,6 +22,12 @@
from openedx.core.djangoapps.django_comment_common.comment_client.comment import Comment
from openedx.core.djangoapps.django_comment_common.comment_client.thread import Thread
from openedx.core.djangoapps.django_comment_common.comment_client.user import User
+from openedx.core.djangoapps.django_comment_common.models import (
+ FORUM_ROLE_ADMINISTRATOR,
+ FORUM_ROLE_COMMUNITY_TA,
+ FORUM_ROLE_MODERATOR,
+ Role,
+)
def _get_context(
@@ -202,3 +212,102 @@ def test_comment(self, is_author, is_thread_author, is_privileged):
thread=Thread(user_id="5" if is_thread_author else "6")
)
assert can_delete(comment, context) == (is_author or is_privileged)
+
+
+@ddt.ddt
+class IsAllowedToRestoreTest(ModuleStoreTestCase):
+ """Tests for IsAllowedToRestore permission class"""
+
+ def setUp(self):
+ super().setUp()
+ self.course = CourseFactory.create()
+ self.permission = IsAllowedToRestore()
+
+ def _create_mock_request(self, user, course_id):
+ """Helper to create a mock request object"""
+ request = Mock()
+ request.user = user
+ request.data = {"course_id": str(course_id)}
+ return request
+
+ def _create_mock_view(self):
+ """Helper to create a mock view object"""
+ return Mock()
+
+ def test_unauthenticated_user_denied(self):
+ """Test that unauthenticated users are denied"""
+ user = Mock()
+ user.is_authenticated = False
+ request = self._create_mock_request(user, self.course.id)
+ view = self._create_mock_view()
+
+ assert not self.permission.has_permission(request, view)
+
+ def test_missing_course_id_denied(self):
+ """Test that requests without course_id are denied"""
+ user = UserFactory.create()
+ request = Mock()
+ request.user = user
+ request.data = {} # No course_id
+ view = self._create_mock_view()
+
+ assert not self.permission.has_permission(request, view)
+
+ def test_invalid_course_id_denied(self):
+ """Test that requests with invalid course_id are denied"""
+ user = UserFactory.create()
+ request = Mock()
+ request.user = user
+ request.data = {"course_id": "invalid-course-id"}
+ view = self._create_mock_view()
+
+ assert not self.permission.has_permission(request, view)
+
+ def test_global_staff_allowed(self):
+ """Test that global staff users are allowed"""
+ user = UserFactory.create(is_staff=True)
+ request = self._create_mock_request(user, self.course.id)
+ view = self._create_mock_view()
+
+ assert self.permission.has_permission(request, view)
+
+ def test_course_staff_allowed(self):
+ """Test that course staff are allowed"""
+ user = UserFactory.create()
+ CourseStaffRole(self.course.id).add_users(user)
+ request = self._create_mock_request(user, self.course.id)
+ view = self._create_mock_view()
+
+ assert self.permission.has_permission(request, view)
+
+ def test_course_instructor_allowed(self):
+ """Test that course instructors are allowed"""
+ user = UserFactory.create()
+ CourseInstructorRole(self.course.id).add_users(user)
+ request = self._create_mock_request(user, self.course.id)
+ view = self._create_mock_view()
+
+ assert self.permission.has_permission(request, view)
+
+ @ddt.data(
+ FORUM_ROLE_ADMINISTRATOR,
+ FORUM_ROLE_MODERATOR,
+ FORUM_ROLE_COMMUNITY_TA,
+ )
+ def test_discussion_privileged_users_allowed(self, role_name):
+ """Test that discussion privileged users (moderator, community TA, administrator) are allowed"""
+ user = UserFactory.create()
+ role = Role.objects.get_or_create(name=role_name, course_id=self.course.id)[0]
+ role.users.add(user)
+ request = self._create_mock_request(user, self.course.id)
+ view = self._create_mock_view()
+
+ assert self.permission.has_permission(request, view)
+
+ def test_regular_user_denied(self):
+ """Test that regular users without privileges are denied"""
+ user = UserFactory.create()
+ request = self._create_mock_request(user, self.course.id)
+ view = self._create_mock_view()
+
+ assert not self.permission.has_permission(request, view)
diff --git a/lms/djangoapps/discussion/rest_api/tests/test_serializers.py b/lms/djangoapps/discussion/rest_api/tests/test_serializers.py
index 0cbcc0bebdd1..812bf7a6b9b3 100644
--- a/lms/djangoapps/discussion/rest_api/tests/test_serializers.py
+++ b/lms/djangoapps/discussion/rest_api/tests/test_serializers.py
@@ -9,19 +9,17 @@
import httpretty
from django.test.client import RequestFactory
from django.test.utils import override_settings
-from xmodule.modulestore import ModuleStoreEnum
-from xmodule.modulestore.django import modulestore
-from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase
-from xmodule.modulestore.tests.factories import CourseFactory
from common.djangoapps.student.tests.factories import UserFactory
from common.djangoapps.util.testing import UrlResetMixin
-from lms.djangoapps.discussion.django_comment_client.tests.utils import ForumsEnableMixin
+from lms.djangoapps.discussion.django_comment_client.tests.utils import (
+ ForumsEnableMixin,
+)
from lms.djangoapps.discussion.rest_api.serializers import (
CommentSerializer,
ThreadSerializer,
filter_spam_urls_from_html,
- get_context
+ get_context,
)
from lms.djangoapps.discussion.rest_api.tests.utils import (
CommentsServiceMockMixin,
@@ -39,6 +37,10 @@
FORUM_ROLE_STUDENT,
Role,
)
+from xmodule.modulestore import ModuleStoreEnum
+from xmodule.modulestore.django import modulestore
+from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase
+from xmodule.modulestore.tests.factories import CourseFactory
@ddt.ddt
@@ -46,13 +48,18 @@ class SerializerTestMixin(ForumsEnableMixin, CommentsServiceMockMixin, UrlResetM
"""
Test Mixin for Serializer tests
"""
+
@classmethod
- @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True})
+ @mock.patch.dict(
+ "django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}
+ )
def setUpClass(cls):
super().setUpClass()
cls.course = CourseFactory.create()
- @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True})
+ @mock.patch.dict(
+ "django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}
+ )
def setUp(self):
super().setUp()
httpretty.reset()
@@ -60,8 +67,8 @@ def setUp(self):
self.addCleanup(httpretty.reset)
self.addCleanup(httpretty.disable)
patcher = mock.patch(
- 'openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled',
- return_value=False
+ "openedx.core.djangoapps.discussions.config.waffle.ENABLE_FORUM_V2.is_enabled",
+ return_value=False,
)
patcher.start()
self.addCleanup(patcher.stop)
@@ -89,7 +96,9 @@ def create_role(self, role_name, users, course=None):
(FORUM_ROLE_STUDENT, False, True, True),
)
@ddt.unpack
- def test_anonymity(self, role_name, anonymous, anonymous_to_peers, expected_serialized_anonymous):
+ def test_anonymity(
+ self, role_name, anonymous, anonymous_to_peers, expected_serialized_anonymous
+ ):
"""
Test that content is properly made anonymous.
@@ -107,7 +116,9 @@ def test_anonymity(self, role_name, anonymous, anonymous_to_peers, expected_seri
"""
self.create_role(role_name, [self.user])
serialized = self.serialize(
- self.make_cs_content({"anonymous": anonymous, "anonymous_to_peers": anonymous_to_peers})
+ self.make_cs_content(
+ {"anonymous": anonymous, "anonymous_to_peers": anonymous_to_peers}
+ )
)
actual_serialized_anonymous = serialized["author"] is None
assert actual_serialized_anonymous == expected_serialized_anonymous
@@ -138,17 +149,19 @@ def test_author_labels(self, role_name, anonymous, expected_label):
"""
self.create_role(role_name, [self.author])
serialized = self.serialize(self.make_cs_content({"anonymous": anonymous}))
- assert serialized['author_label'] == expected_label
+ assert serialized["author_label"] == expected_label
def test_abuse_flagged(self):
- serialized = self.serialize(self.make_cs_content({"abuse_flaggers": [str(self.user.id)]}))
- assert serialized['abuse_flagged'] is True
+ serialized = self.serialize(
+ self.make_cs_content({"abuse_flaggers": [str(self.user.id)]})
+ )
+ assert serialized["abuse_flagged"] is True
def test_voted(self):
thread_id = "test_thread"
self.register_get_user_response(self.user, upvoted_ids=[thread_id])
serialized = self.serialize(self.make_cs_content({"id": thread_id}))
- assert serialized['voted'] is True
+ assert serialized["voted"] is True
@ddt.ddt
@@ -175,47 +188,62 @@ def serialize(self, thread):
Create a serializer with an appropriate context and use it to serialize
the given thread, returning the result.
"""
- return ThreadSerializer(thread, context=get_context(self.course, self.request)).data
+ return ThreadSerializer(
+ thread, context=get_context(self.course, self.request)
+ ).data
def test_basic(self):
- thread = make_minimal_cs_thread({
- "id": "test_thread",
- "course_id": str(self.course.id),
- "commentable_id": "test_topic",
- "user_id": str(self.author.id),
- "username": self.author.username,
- "title": "Test Title",
- "body": "Test body",
- "pinned": True,
- "votes": {"up_count": 4},
- "comments_count": 5,
- "unread_comments_count": 3,
- })
- expected = self.expected_thread_data({
- "author": self.author.username,
- "can_delete": False,
- "vote_count": 4,
- "comment_count": 6,
- "unread_comment_count": 3,
- "pinned": True,
- "editable_fields": ["abuse_flagged", "copy_link", "following", "read", "voted"],
- "abuse_flagged_count": None,
- "edit_by_label": None,
- "closed_by_label": None,
- })
+ thread = make_minimal_cs_thread(
+ {
+ "id": "test_thread",
+ "course_id": str(self.course.id),
+ "commentable_id": "test_topic",
+ "user_id": str(self.author.id),
+ "username": self.author.username,
+ "title": "Test Title",
+ "body": "Test body",
+ "pinned": True,
+ "votes": {"up_count": 4},
+ "comments_count": 5,
+ "unread_comments_count": 3,
+ }
+ )
+ expected = self.expected_thread_data(
+ {
+ "author": self.author.username,
+ "can_delete": False,
+ "vote_count": 4,
+ "comment_count": 6,
+ "unread_comment_count": 3,
+ "pinned": True,
+ "editable_fields": [
+ "abuse_flagged",
+ "copy_link",
+ "following",
+ "read",
+ "voted",
+ ],
+ "abuse_flagged_count": None,
+ "edit_by_label": None,
+ "closed_by_label": None,
+ "is_deleted": None,
+ }
+ )
assert self.serialize(thread) == expected
thread["thread_type"] = "question"
- expected.update({
- "type": "question",
- "comment_list_url": None,
- "endorsed_comment_list_url": (
- "http://testserver/api/discussion/v1/comments/?thread_id=test_thread&endorsed=True"
- ),
- "non_endorsed_comment_list_url": (
- "http://testserver/api/discussion/v1/comments/?thread_id=test_thread&endorsed=False"
- ),
- })
+ expected.update(
+ {
+ "type": "question",
+ "comment_list_url": None,
+ "endorsed_comment_list_url": (
+ "http://testserver/api/discussion/v1/comments/?thread_id=test_thread&endorsed=True"
+ ),
+ "non_endorsed_comment_list_url": (
+ "http://testserver/api/discussion/v1/comments/?thread_id=test_thread&endorsed=False"
+ ),
+ }
+ )
assert self.serialize(thread) == expected
def test_pinned_missing(self):
@@ -227,34 +255,34 @@ def test_pinned_missing(self):
del thread_data["pinned"]
self.register_get_thread_response(thread_data)
serialized = self.serialize(thread_data)
- assert serialized['pinned'] is False
+ assert serialized["pinned"] is False
def test_group(self):
self.course.cohort_config = {"cohorted": True}
modulestore().update_item(self.course, ModuleStoreEnum.UserID.test)
cohort = CohortFactory.create(course_id=self.course.id)
serialized = self.serialize(self.make_cs_content({"group_id": cohort.id}))
- assert serialized['group_id'] == cohort.id
- assert serialized['group_name'] == cohort.name
+ assert serialized["group_id"] == cohort.id
+ assert serialized["group_name"] == cohort.name
def test_following(self):
thread_id = "test_thread"
self.register_get_user_response(self.user, subscribed_thread_ids=[thread_id])
serialized = self.serialize(self.make_cs_content({"id": thread_id}))
- assert serialized['following'] is True
+ assert serialized["following"] is True
def test_response_count(self):
thread_data = self.make_cs_content({"resp_total": 2})
self.register_get_thread_response(thread_data)
serialized = self.serialize(thread_data)
- assert serialized['response_count'] == 2
+ assert serialized["response_count"] == 2
def test_response_count_missing(self):
thread_data = self.make_cs_content({})
del thread_data["resp_total"]
self.register_get_thread_response(thread_data)
serialized = self.serialize(thread_data)
- assert 'response_count' not in serialized
+ assert "response_count" not in serialized
@ddt.data(
(FORUM_ROLE_MODERATOR, True),
@@ -272,43 +300,62 @@ def test_closed_by_label_field(self, role, visible):
self.create_role(FORUM_ROLE_MODERATOR, [moderator])
self.create_role(request_role, [self.user])
- thread = make_minimal_cs_thread({
- "id": "test_thread",
- "course_id": str(self.course.id),
- "commentable_id": "test_topic",
- "user_id": str(author.id),
- "username": author.username,
- "title": "Test Title",
- "body": "Test body",
- "pinned": True,
- "votes": {"up_count": 4},
- "comments_count": 5,
- "unread_comments_count": 3,
- "closed_by": moderator
- })
+ thread = make_minimal_cs_thread(
+ {
+ "id": "test_thread",
+ "course_id": str(self.course.id),
+ "commentable_id": "test_topic",
+ "user_id": str(author.id),
+ "username": author.username,
+ "title": "Test Title",
+ "body": "Test body",
+ "pinned": True,
+ "votes": {"up_count": 4},
+ "comments_count": 5,
+ "unread_comments_count": 3,
+ "closed_by": moderator,
+ }
+ )
closed_by_label = "Moderator" if visible else None
closed_by = moderator if visible else None
can_delete = role != FORUM_ROLE_STUDENT
editable_fields = ["abuse_flagged", "copy_link", "following", "read", "voted"]
if role == "author":
editable_fields.remove("voted")
- editable_fields.extend(['anonymous', 'raw_body', 'title', 'topic_id', 'type'])
+ editable_fields.extend(
+ ["anonymous", "raw_body", "title", "topic_id", "type"]
+ )
elif role == FORUM_ROLE_MODERATOR:
- editable_fields.extend(['close_reason_code', 'closed', 'edit_reason_code', 'pinned',
- 'raw_body', 'title', 'topic_id', 'type'])
- expected = self.expected_thread_data({
- "author": author.username,
- "can_delete": can_delete,
- "vote_count": 4,
- "comment_count": 6,
- "unread_comment_count": 3,
- "pinned": True,
- "editable_fields": sorted(editable_fields),
- "abuse_flagged_count": None,
- "edit_by_label": None,
- "closed_by_label": closed_by_label,
- "closed_by": closed_by,
- })
+ editable_fields.extend(
+ [
+ "close_reason_code",
+ "closed",
+ "edit_reason_code",
+ "pinned",
+ "raw_body",
+ "title",
+ "topic_id",
+ "type",
+ ]
+ )
+ # is_deleted is visible (False) for privileged users and authors, hidden (None) for others
+ is_deleted = False if role in (FORUM_ROLE_MODERATOR, "author") else None
+ expected = self.expected_thread_data(
+ {
+ "author": author.username,
+ "can_delete": can_delete,
+ "vote_count": 4,
+ "comment_count": 6,
+ "unread_comment_count": 3,
+ "pinned": True,
+ "editable_fields": sorted(editable_fields),
+ "abuse_flagged_count": None,
+ "edit_by_label": None,
+ "closed_by_label": closed_by_label,
+ "closed_by": closed_by,
+ "is_deleted": is_deleted,
+ }
+ )
assert self.serialize(thread) == expected
@ddt.data(
@@ -327,48 +374,69 @@ def test_edit_by_label_field(self, role, visible):
self.create_role(FORUM_ROLE_MODERATOR, [moderator])
self.create_role(request_role, [self.user])
- thread = make_minimal_cs_thread({
- "id": "test_thread",
- "course_id": str(self.course.id),
- "commentable_id": "test_topic",
- "user_id": str(author.id),
- "username": author.username,
- "title": "Test Title",
- "body": "Test body",
- "pinned": True,
- "votes": {"up_count": 4},
- "edit_history": [{"editor_username": moderator}],
- "comments_count": 5,
- "unread_comments_count": 3,
- "closed_by": None
- })
+ thread = make_minimal_cs_thread(
+ {
+ "id": "test_thread",
+ "course_id": str(self.course.id),
+ "commentable_id": "test_topic",
+ "user_id": str(author.id),
+ "username": author.username,
+ "title": "Test Title",
+ "body": "Test body",
+ "pinned": True,
+ "votes": {"up_count": 4},
+ "edit_history": [{"editor_username": moderator}],
+ "comments_count": 5,
+ "unread_comments_count": 3,
+ "closed_by": None,
+ }
+ )
edit_by_label = "Moderator" if visible else None
can_delete = role != FORUM_ROLE_STUDENT
- last_edit = None if role == FORUM_ROLE_STUDENT else {"editor_username": moderator}
+ last_edit = (
+ None if role == FORUM_ROLE_STUDENT else {"editor_username": moderator}
+ )
editable_fields = ["abuse_flagged", "copy_link", "following", "read", "voted"]
if role == "author":
editable_fields.remove("voted")
- editable_fields.extend(['anonymous', 'raw_body', 'title', 'topic_id', 'type'])
+ editable_fields.extend(
+ ["anonymous", "raw_body", "title", "topic_id", "type"]
+ )
elif role == FORUM_ROLE_MODERATOR:
- editable_fields.extend(['close_reason_code', 'closed', 'edit_reason_code', 'pinned',
- 'raw_body', 'title', 'topic_id', 'type'])
+ editable_fields.extend(
+ [
+ "close_reason_code",
+ "closed",
+ "edit_reason_code",
+ "pinned",
+ "raw_body",
+ "title",
+ "topic_id",
+ "type",
+ ]
+ )
- expected = self.expected_thread_data({
- "author": author.username,
- "can_delete": can_delete,
- "vote_count": 4,
- "comment_count": 6,
- "unread_comment_count": 3,
- "pinned": True,
- "editable_fields": sorted(editable_fields),
- "abuse_flagged_count": None,
- "last_edit": last_edit,
- "edit_by_label": edit_by_label,
- "closed_by_label": None,
- "closed_by": None,
- })
+ # is_deleted is visible (False) for privileged users and authors, hidden (None) for others
+ is_deleted = False if role in (FORUM_ROLE_MODERATOR, "author") else None
+ expected = self.expected_thread_data(
+ {
+ "author": author.username,
+ "can_delete": can_delete,
+ "vote_count": 4,
+ "comment_count": 6,
+ "unread_comment_count": 3,
+ "pinned": True,
+ "editable_fields": sorted(editable_fields),
+ "abuse_flagged_count": None,
+ "last_edit": last_edit,
+ "edit_by_label": edit_by_label,
+ "closed_by_label": None,
+ "closed_by": None,
+ "is_deleted": is_deleted,
+ }
+ )
assert self.serialize(thread) == expected
def test_get_preview_body(self):
@@ -384,7 +452,10 @@ def test_get_preview_body(self):
{"body": "
+ {% trans "Action Required:" as action_required %}{{ action_required|force_escape }}
+ {% trans "Please review this moderation action and follow up as needed. If this ban was applied in error or requires adjustment, contact the moderator or course staff." as review_instructions %}{{ review_instructions|force_escape }}
+