diff --git a/forum/__init__.py b/forum/__init__.py index 63ea5f32..fb5dbe93 100644 --- a/forum/__init__.py +++ b/forum/__init__.py @@ -2,4 +2,4 @@ Openedx forum app. """ -__version__ = "0.4.4" +__version__ = "0.4.5" diff --git a/forum/backends/backend.py b/forum/backends/backend.py index c281ace2..d624f240 100644 --- a/forum/backends/backend.py +++ b/forum/backends/backend.py @@ -223,7 +223,9 @@ def get_commentables_counts_based_on_type(course_id: str) -> dict[str, Any]: raise NotImplementedError @classmethod - def get_user_voted_ids(cls, user_id: str, vote: str) -> list[str]: + def get_user_voted_ids( + cls, user_id: str, vote: str, course_id: Optional[str] = None + ) -> list[str]: """Get user voted ids.""" raise NotImplementedError diff --git a/forum/backends/mongodb/api.py b/forum/backends/mongodb/api.py index 3bec9249..dcf18853 100644 --- a/forum/backends/mongodb/api.py +++ b/forum/backends/mongodb/api.py @@ -1159,13 +1159,20 @@ def get_commentables_counts_based_on_type(course_id: str) -> dict[str, Any]: return commentable_counts @classmethod - def get_user_voted_ids(cls, user_id: str, vote: str) -> list[str]: + def get_user_voted_ids( + cls, user_id: str, vote: str, course_id: Optional[str] = None + ) -> list[str]: """Get the IDs of the posts voted by a user.""" if vote not in ["up", "down"]: raise ValueError("Invalid vote type") content_model = Contents() - contents = content_model.get_list() + content_query: dict[str, Any] = {} + if course_id: + content_query["course_id"] = str(course_id) + content_query[f"votes.{vote}"] = {"$in": [user_id, str(user_id)]} + + contents = content_model.get_list(**content_query) voted_ids = [] for content in contents: votes = content["votes"][vote] @@ -1207,8 +1214,12 @@ def user_to_hash( if params.get("complete"): subscribed_thread_ids = cls.find_subscribed_threads(user["external_id"]) - upvoted_ids = cls.get_user_voted_ids(user["external_id"], "up") - downvoted_ids = cls.get_user_voted_ids(user["external_id"], "down") + upvoted_ids = cls.get_user_voted_ids( + user["external_id"], "up", params.get("course_id") + ) + downvoted_ids = cls.get_user_voted_ids( + user["external_id"], "down", params.get("course_id") + ) hash_data.update( { "subscribed_thread_ids": subscribed_thread_ids, diff --git a/forum/backends/mysql/api.py b/forum/backends/mysql/api.py index ca3d8524..51a3342a 100644 --- a/forum/backends/mysql/api.py +++ b/forum/backends/mysql/api.py @@ -1132,7 +1132,9 @@ def get_threads( return threads @classmethod - def get_user_voted_ids(cls, user_id: str, vote: str) -> list[str]: + def get_user_voted_ids( + cls, user_id: str, vote: str, course_id: Optional[str] = None + ) -> list[str]: """Get the IDs of the posts voted by a user.""" if vote not in ["up", "down"]: raise ValueError("Invalid vote type") diff --git a/tests/e2e/test_users.py b/tests/e2e/test_users.py index 893240cf..45323608 100644 --- a/tests/e2e/test_users.py +++ b/tests/e2e/test_users.py @@ -624,7 +624,12 @@ def test_update_user_stats(api_client: APIClient, patched_get_backend: Any) -> N # Sort the data for expected result (threads, responses, replies) expected_result = sorted( expected_data.values(), - key=lambda val: (val["threads"], val["responses"], val["replies"]), + key=lambda val: ( + val["threads"], + val["responses"], + val["replies"], + val["username"], + ), reverse=True, )