diff --git a/articles/migrations/0008_article_slug.py b/articles/migrations/0008_article_slug.py new file mode 100644 index 0000000000..e37d232cc7 --- /dev/null +++ b/articles/migrations/0008_article_slug.py @@ -0,0 +1,17 @@ +# Generated by Django 4.2.25 on 2025-12-08 09:42 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("articles", "0007_add_editors_group"), + ] + + operations = [ + migrations.AddField( + model_name="article", + name="slug", + field=models.SlugField(blank=True, max_length=255, null=True, unique=True), + ), + ] diff --git a/articles/models.py b/articles/models.py index 56a47ee1b7..ca9faba0db 100644 --- a/articles/models.py +++ b/articles/models.py @@ -2,6 +2,7 @@ from django.conf import settings from django.db import models +from django.utils.text import slugify from main.models import TimestampedModel from profiles.utils import article_image_upload_uri @@ -20,15 +21,38 @@ class Article(TimestampedModel): ) content = models.JSONField(default={}) title = models.CharField(max_length=255) + slug = models.SlugField(max_length=255, unique=True, blank=True, null=True) is_published = models.BooleanField(default=False) + def save(self, *args, **kwargs): + previous = Article.objects.get(pk=self.pk) if self.pk else None + was_published = getattr(previous, "is_published", None) + + # Always initialize slug + slug = self.slug or None + + if not was_published and self.is_published: + max_length = self._meta.get_field("slug").max_length + + base_slug = slugify(self.title)[:max_length] + slug = base_slug + counter = 1 + + # Prevent collisions + while Article.objects.filter(slug=slug).exclude(pk=self.pk).exists(): + suffix = f"-{counter}" + slug = f"{base_slug[: max_length - len(suffix)]}{suffix}" + counter += 1 + + self.slug = slug + super().save(*args, **kwargs) + class ArticleImageUpload(models.Model): user = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE) image_file = models.ImageField( null=True, upload_to=article_image_upload_uri, max_length=2083, editable=False ) - created_at = models.DateTimeField(auto_now_add=True) def __str__(self): diff --git a/articles/serializers.py b/articles/serializers.py index ae3f5296d8..a47279bc5c 100644 --- a/articles/serializers.py +++ b/articles/serializers.py @@ -32,6 +32,7 @@ class RichTextArticleSerializer(serializers.ModelSerializer): created_on = serializers.DateTimeField(read_only=True, required=False) updated_on = serializers.DateTimeField(read_only=True, required=False) content = serializers.JSONField(default={}) + slug = serializers.SlugField(max_length=60, required=False, allow_blank=True) title = serializers.CharField(max_length=255) user = UserSerializer(read_only=True) @@ -45,6 +46,7 @@ class Meta: "created_on", "updated_on", "is_published", + "slug", ] diff --git a/articles/urls.py b/articles/urls.py index 435f4a99cd..d5f0da8ada 100644 --- a/articles/urls.py +++ b/articles/urls.py @@ -1,5 +1,3 @@ -"""URL configuration for staff_content""" - from django.urls import include, path, re_path from rest_framework.routers import SimpleRouter @@ -15,13 +13,16 @@ ) app_name = "articles" + urlpatterns = [ re_path( r"^api/v1/", include( ( [ + # All ViewSet routes *v1_router.urls, + # Media upload endpoint path( "upload-media/", MediaUploadView.as_view(), diff --git a/articles/views.py b/articles/views.py index ae245ddf2b..897aaef386 100644 --- a/articles/views.py +++ b/articles/views.py @@ -1,11 +1,14 @@ from django.conf import settings +from django.shortcuts import get_object_or_404 from django.utils.decorators import method_decorator from drf_spectacular.utils import ( + OpenApiParameter, OpenApiResponse, extend_schema, extend_schema_view, ) from rest_framework import status, viewsets +from rest_framework.decorators import action from rest_framework.pagination import LimitOffsetPagination from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response @@ -77,6 +80,37 @@ def destroy(self, request, *args, **kwargs): clear_views_cache() return super().destroy(request, *args, **kwargs) + @extend_schema( + summary="Retrieve article by ID or slug", + description="If the path parameter is numeric → ID, else → slug.", + parameters=[ + OpenApiParameter( + name="identifier", + type=str, + location=OpenApiParameter.PATH, + description="Article ID (number) or slug (string)", + required=True, + ) + ], + responses={200: RichTextArticleSerializer, 404: OpenApiResponse()}, + ) + @action( + detail=False, + methods=["get"], + url_path="detail/(?P[^/.]+)", + url_name="detail-by-id-or-slug", + ) + def detail_by_id_or_slug(self, _request, identifier): + qs = self.get_queryset() + + if identifier.isdigit(): + article = get_object_or_404(qs, id=int(identifier)) + else: + article = get_object_or_404(qs, slug=identifier) + + serializer = self.get_serializer(article) + return Response(serializer.data, status=status.HTTP_200_OK) + @extend_schema_view( post=extend_schema( diff --git a/articles/views_test.py b/articles/views_test.py index 748eda8a5e..36be0fde0c 100644 --- a/articles/views_test.py +++ b/articles/views_test.py @@ -3,6 +3,7 @@ import pytest from rest_framework.reverse import reverse +from articles.models import Article from main.factories import UserFactory pytestmark = [pytest.mark.django_db] @@ -22,11 +23,69 @@ def test_article_creation(staff_client, user): assert json["title"] == "Some title" -@pytest.mark.parametrize("is_staff", [True, False]) -def test_article_permissions(client, is_staff): - user = UserFactory.create(is_staff=True) - client.force_login(user) - url = reverse("articles:v1:articles-list") +def test_retrieve_article_by_id(client, user): + """Should retrieve published article by numeric ID""" + article = Article.objects.create( + title="Test Article", + content={}, + is_published=True, + user=user, + ) + + url = reverse( + "articles:v1:articles-detail-by-id-or-slug", + kwargs={"identifier": str(article.id)}, + ) + + resp = client.get(url) + data = resp.json() + + assert resp.status_code == 200 + assert data["id"] == article.id + assert data["title"] == "Test Article" + + +def test_retrieve_article_by_slug(client, user): + """Should retrieve published article by slug""" + article = Article.objects.create( + title="Slug Article", + content={}, + is_published=True, + user=user, + ) + + url = reverse( + "articles:v1:articles-detail-by-id-or-slug", + kwargs={"identifier": article.slug}, + ) + resp = client.get(url) - resp.json() - assert resp.status_code == 200 if is_staff else 403 + data = resp.json() + + assert resp.status_code == 200 + assert data["slug"] == article.slug + assert data["title"] == "Slug Article" + + +def test_staff_can_access_unpublished_article(client): + """Staff should be able to see unpublished articles""" + staff_user = UserFactory.create(is_staff=True) + client.force_login(staff_user) + + article = Article.objects.create( + title="Draft Article", + content={}, + is_published=False, + user=staff_user, + ) + + url = reverse( + "articles:v1:articles-detail-by-id-or-slug", + kwargs={"identifier": str(article.id)}, + ) + + resp = client.get(url) + data = resp.json() + + assert resp.status_code == 200 + assert data["id"] == article.id diff --git a/frontends/api/src/generated/v1/api.ts b/frontends/api/src/generated/v1/api.ts index 3b599cd82f..d95b67f368 100644 --- a/frontends/api/src/generated/v1/api.ts +++ b/frontends/api/src/generated/v1/api.ts @@ -4837,6 +4837,12 @@ export interface PatchedRichTextArticleRequest { * @memberof PatchedRichTextArticleRequest */ is_published?: boolean + /** + * + * @type {string} + * @memberof PatchedRichTextArticleRequest + */ + slug?: string } /** * Serializer for UserListRelationship model @@ -7270,6 +7276,12 @@ export interface RichTextArticle { * @memberof RichTextArticle */ is_published?: boolean + /** + * + * @type {string} + * @memberof RichTextArticle + */ + slug?: string } /** * Serializer for LearningResourceInstructor model @@ -7295,6 +7307,12 @@ export interface RichTextArticleRequest { * @memberof RichTextArticleRequest */ is_published?: boolean + /** + * + * @type {string} + * @memberof RichTextArticleRequest + */ + slug?: string } /** * * `phrase` - phrase * `best_fields` - best_fields * `most_fields` - most_fields * `hybrid` - hybrid * `phrase` - phrase * `best_fields` - best_fields * `most_fields` - most_fields * `hybrid` - hybrid @@ -8922,6 +8940,52 @@ export const ArticlesApiAxiosParamCreator = function ( options: localVarRequestOptions, } }, + /** + * If the path parameter is numeric → ID, else → slug. + * @summary Retrieve article by ID or slug + * @param {string} identifier Article ID (number) or slug (string) + * @param {*} [options] Override http request option. + * @throws {RequiredError} + */ + articlesDetailRetrieve: async ( + identifier: string, + options: RawAxiosRequestConfig = {}, + ): Promise => { + // verify required parameter 'identifier' is not null or undefined + assertParamExists("articlesDetailRetrieve", "identifier", identifier) + const localVarPath = `/api/v1/articles/detail/{identifier}/`.replace( + `{${"identifier"}}`, + encodeURIComponent(String(identifier)), + ) + // use dummy base URL string because the URL constructor only accepts absolute URLs. + const localVarUrlObj = new URL(localVarPath, DUMMY_BASE_URL) + let baseOptions + if (configuration) { + baseOptions = configuration.baseOptions + } + + const localVarRequestOptions = { + method: "GET", + ...baseOptions, + ...options, + } + const localVarHeaderParameter = {} as any + const localVarQueryParameter = {} as any + + setSearchParams(localVarUrlObj, localVarQueryParameter) + let headersFromBaseOptions = + baseOptions && baseOptions.headers ? baseOptions.headers : {} + localVarRequestOptions.headers = { + ...localVarHeaderParameter, + ...headersFromBaseOptions, + ...options.headers, + } + + return { + url: toPathString(localVarUrlObj), + options: localVarRequestOptions, + } + }, /** * Get a paginated list of articles * @summary List @@ -9143,6 +9207,38 @@ export const ArticlesApiFp = function (configuration?: Configuration) { configuration, )(axios, operationBasePath || basePath) }, + /** + * If the path parameter is numeric → ID, else → slug. + * @summary Retrieve article by ID or slug + * @param {string} identifier Article ID (number) or slug (string) + * @param {*} [options] Override http request option. + * @throws {RequiredError} + */ + async articlesDetailRetrieve( + identifier: string, + options?: RawAxiosRequestConfig, + ): Promise< + ( + axios?: AxiosInstance, + basePath?: string, + ) => AxiosPromise + > { + const localVarAxiosArgs = + await localVarAxiosParamCreator.articlesDetailRetrieve( + identifier, + options, + ) + const index = configuration?.serverIndex ?? 0 + const operationBasePath = + operationServerMap["ArticlesApi.articlesDetailRetrieve"]?.[index]?.url + return (axios, basePath) => + createRequestFunction( + localVarAxiosArgs, + globalAxios, + BASE_PATH, + configuration, + )(axios, operationBasePath || basePath) + }, /** * Get a paginated list of articles * @summary List @@ -9285,6 +9381,21 @@ export const ArticlesApiFactory = function ( .articlesDestroy(requestParameters.id, options) .then((request) => request(axios, basePath)) }, + /** + * If the path parameter is numeric → ID, else → slug. + * @summary Retrieve article by ID or slug + * @param {ArticlesApiArticlesDetailRetrieveRequest} requestParameters Request parameters. + * @param {*} [options] Override http request option. + * @throws {RequiredError} + */ + articlesDetailRetrieve( + requestParameters: ArticlesApiArticlesDetailRetrieveRequest, + options?: RawAxiosRequestConfig, + ): AxiosPromise { + return localVarFp + .articlesDetailRetrieve(requestParameters.identifier, options) + .then((request) => request(axios, basePath)) + }, /** * Get a paginated list of articles * @summary List @@ -9369,6 +9480,20 @@ export interface ArticlesApiArticlesDestroyRequest { readonly id: number } +/** + * Request parameters for articlesDetailRetrieve operation in ArticlesApi. + * @export + * @interface ArticlesApiArticlesDetailRetrieveRequest + */ +export interface ArticlesApiArticlesDetailRetrieveRequest { + /** + * Article ID (number) or slug (string) + * @type {string} + * @memberof ArticlesApiArticlesDetailRetrieve + */ + readonly identifier: string +} + /** * Request parameters for articlesList operation in ArticlesApi. * @export @@ -9466,6 +9591,23 @@ export class ArticlesApi extends BaseAPI { .then((request) => request(this.axios, this.basePath)) } + /** + * If the path parameter is numeric → ID, else → slug. + * @summary Retrieve article by ID or slug + * @param {ArticlesApiArticlesDetailRetrieveRequest} requestParameters Request parameters. + * @param {*} [options] Override http request option. + * @throws {RequiredError} + * @memberof ArticlesApi + */ + public articlesDetailRetrieve( + requestParameters: ArticlesApiArticlesDetailRetrieveRequest, + options?: RawAxiosRequestConfig, + ) { + return ArticlesApiFp(this.configuration) + .articlesDetailRetrieve(requestParameters.identifier, options) + .then((request) => request(this.axios, this.basePath)) + } + /** * Get a paginated list of articles * @summary List diff --git a/frontends/api/src/hooks/articles/index.ts b/frontends/api/src/hooks/articles/index.ts index ba3d1f7b14..210238a86f 100644 --- a/frontends/api/src/hooks/articles/index.ts +++ b/frontends/api/src/hooks/articles/index.ts @@ -28,6 +28,13 @@ const useArticleDetail = (id: number | undefined) => { }) } +const useArticleDetailRetrieve = (identifier: string | undefined) => { + return useQuery({ + ...articleQueries.articlesDetailRetrieve(identifier ?? ""), + enabled: identifier !== undefined, + }) +} + const useArticleCreate = () => { const client = useQueryClient() return useMutation({ @@ -80,6 +87,10 @@ const useArticlePartialUpdate = () => { .then((response) => response.data), onSuccess: (article: Article) => { client.invalidateQueries({ queryKey: articleKeys.detail(article.id) }) + const identifier = article.slug || article.id.toString() + client.invalidateQueries({ + queryKey: articleKeys.articlesDetailRetrieve(identifier), + }) }, }) } @@ -90,4 +101,5 @@ export { useArticleCreate, useArticleDestroy, useArticlePartialUpdate, + useArticleDetailRetrieve, } diff --git a/frontends/api/src/hooks/articles/queries.ts b/frontends/api/src/hooks/articles/queries.ts index c045dacaff..4261642631 100644 --- a/frontends/api/src/hooks/articles/queries.ts +++ b/frontends/api/src/hooks/articles/queries.ts @@ -8,6 +8,10 @@ const articleKeys = { list: (params: ArticleListRequest) => [...articleKeys.listRoot(), params], detailRoot: () => [...articleKeys.root, "detail"], detail: (id: number) => [...articleKeys.detailRoot(), id], + articlesDetailRetrieve: (identifier: string) => [ + ...articleKeys.detailRoot(), + identifier, + ], } const articleQueries = { @@ -22,6 +26,14 @@ const articleQueries = { queryFn: () => articlesApi.articlesRetrieve({ id }).then((res) => res.data), }), + articlesDetailRetrieve: (identifier: string) => + queryOptions({ + queryKey: articleKeys.articlesDetailRetrieve(identifier), + queryFn: () => + articlesApi + .articlesDetailRetrieve({ identifier }) + .then((res) => res.data), + }), } export { articleQueries, articleKeys } diff --git a/frontends/api/src/test-utils/urls.ts b/frontends/api/src/test-utils/urls.ts index 5b77a2caa5..86473b0440 100644 --- a/frontends/api/src/test-utils/urls.ts +++ b/frontends/api/src/test-utils/urls.ts @@ -153,6 +153,8 @@ const articles = { list: (params?: Params) => `${API_BASE_URL}/api/v1/articles/${query(params)}`, details: (id: number) => `${API_BASE_URL}/api/v1/articles/${id}/`, + articlesDetailRetrieve: (identifier: string) => + `${API_BASE_URL}/api/v1/articles/detail/${identifier}/`, } const userSubscription = { diff --git a/frontends/main/src/app-pages/Articles/ArticleDetailPage.tsx b/frontends/main/src/app-pages/Articles/ArticleDetailPage.tsx index cabd4d98bc..62ae2cf93f 100644 --- a/frontends/main/src/app-pages/Articles/ArticleDetailPage.tsx +++ b/frontends/main/src/app-pages/Articles/ArticleDetailPage.tsx @@ -1,7 +1,7 @@ "use client" import React from "react" -import { useArticleDetail } from "api/hooks/articles" +import { useArticleDetailRetrieve } from "api/hooks/articles" import { LoadingSpinner, ArticleEditor, styled } from "ol-components" import { notFound } from "next/navigation" import { useFeatureFlagEnabled } from "posthog-js/react" @@ -12,12 +12,12 @@ const PageContainer = styled.div({ height: "100%", }) -export const ArticleDetailPage = ({ articleId }: { articleId: number }) => { +export const ArticleDetailPage = ({ articleId }: { articleId: string }) => { const { data: article, isLoading, isFetching, - } = useArticleDetail(Number(articleId)) + } = useArticleDetailRetrieve(articleId) const showArticleDetail = useFeatureFlagEnabled( FeatureFlags.ArticleEditorView, diff --git a/frontends/main/src/app-pages/Articles/ArticleEditPage.test.tsx b/frontends/main/src/app-pages/Articles/ArticleEditPage.test.tsx index 031e456783..1e452b385d 100644 --- a/frontends/main/src/app-pages/Articles/ArticleEditPage.test.tsx +++ b/frontends/main/src/app-pages/Articles/ArticleEditPage.test.tsx @@ -55,7 +55,11 @@ describe.skip("ArticleEditPage", () => { ], }, }) - setMockResponse.get(urls.articles.details(article.id), article) + setMockResponse.get( + urls.articles.articlesDetailRetrieve(String(article.id)), + article, + ) + renderWithProviders() await screen.findByTestId("editor") expect(screen.getByText("Existing Title")).toBeInTheDocument() diff --git a/frontends/main/src/app-pages/Articles/ArticleEditPage.tsx b/frontends/main/src/app-pages/Articles/ArticleEditPage.tsx index 3098b936ae..51ea33871c 100644 --- a/frontends/main/src/app-pages/Articles/ArticleEditPage.tsx +++ b/frontends/main/src/app-pages/Articles/ArticleEditPage.tsx @@ -4,7 +4,7 @@ import React from "react" import { useRouter } from "next-nprogress-bar" import { notFound } from "next/navigation" import { Permission } from "api/hooks/user" -import { useArticleDetail } from "api/hooks/articles" +import { useArticleDetailRetrieve } from "api/hooks/articles" import RestrictedRoute from "@/components/RestrictedRoute/RestrictedRoute" import { styled, LoadingSpinner, ArticleEditor } from "ol-components" import { articlesView } from "@/common/urls" @@ -20,7 +20,7 @@ const ArticleEditPage = ({ articleId }: { articleId: string }) => { data: article, isLoading, isFetching, - } = useArticleDetail(Number(articleId)) + } = useArticleDetailRetrieve(articleId) const router = useRouter() if (isLoading || isFetching) { @@ -36,7 +36,9 @@ const ArticleEditPage = ({ articleId }: { articleId: string }) => { { - router.push(articlesView(article.id)) + if (article.is_published) + return router.push(articlesView(article.slug!)) + router.push(articlesView(String(article.id))) }} /> diff --git a/frontends/main/src/app-pages/Articles/ArticleNewPage.tsx b/frontends/main/src/app-pages/Articles/ArticleNewPage.tsx index 73798a46b0..53c23f288b 100644 --- a/frontends/main/src/app-pages/Articles/ArticleNewPage.tsx +++ b/frontends/main/src/app-pages/Articles/ArticleNewPage.tsx @@ -21,7 +21,9 @@ const ArticleNewPage: React.FC = () => { { - router.push(articlesView(article.id)) + if (article.is_published) + return router.push(articlesView(article.slug!)) + router.push(articlesView(String(article.id))) }} /> diff --git a/frontends/main/src/app/articles/[id]/edit/page.tsx b/frontends/main/src/app/articles/[id]/edit/page.tsx deleted file mode 100644 index 03149399f2..0000000000 --- a/frontends/main/src/app/articles/[id]/edit/page.tsx +++ /dev/null @@ -1,9 +0,0 @@ -import React from "react" -import { ArticleEditPage } from "@/app-pages/Articles/ArticleEditPage" - -const Page: React.FC> = async (props) => { - const params = await props.params - - return -} -export default Page diff --git a/frontends/main/src/app/articles/[slugOrId]/edit/page.tsx b/frontends/main/src/app/articles/[slugOrId]/edit/page.tsx new file mode 100644 index 0000000000..5c69ab1316 --- /dev/null +++ b/frontends/main/src/app/articles/[slugOrId]/edit/page.tsx @@ -0,0 +1,11 @@ +import React from "react" +import { ArticleEditPage } from "@/app-pages/Articles/ArticleEditPage" + +const Page: React.FC> = async ( + props, +) => { + const { slugOrId } = await props.params + + return +} +export default Page diff --git a/frontends/main/src/app/articles/[id]/page.tsx b/frontends/main/src/app/articles/[slugOrId]/page.tsx similarity index 63% rename from frontends/main/src/app/articles/[id]/page.tsx rename to frontends/main/src/app/articles/[slugOrId]/page.tsx index f550dee6ad..cd82ac22cf 100644 --- a/frontends/main/src/app/articles/[id]/page.tsx +++ b/frontends/main/src/app/articles/[slugOrId]/page.tsx @@ -7,9 +7,9 @@ export const metadata: Metadata = standardizeMetadata({ title: "Article Detail", }) -const Page: React.FC> = async (props) => { - const params = await props.params +const Page: React.FC> = async (props) => { + const { slugOrId } = await props.params - return + return } export default Page diff --git a/frontends/main/src/common/urls.ts b/frontends/main/src/common/urls.ts index 3eba0539e7..e60f83fa2b 100644 --- a/frontends/main/src/common/urls.ts +++ b/frontends/main/src/common/urls.ts @@ -27,7 +27,7 @@ export const ARTICLES_LISTING = "/articles/" export const ARTICLES_VIEW = "/articles/[id]" export const ARTICLES_EDIT = "/articles/[id]/edit" export const ARTICLES_CREATE = "/articles/new" -export const articlesView = (id: number) => +export const articlesView = (id: string) => generatePath(ARTICLES_VIEW, { id: String(id) }) export const articlesEditView = (id: number) => generatePath(ARTICLES_EDIT, { id: String(id) }) diff --git a/frontends/ol-components/src/components/TiptapEditor/ArticleEditor.tsx b/frontends/ol-components/src/components/TiptapEditor/ArticleEditor.tsx index e9677bac5d..6b4a63f87c 100644 --- a/frontends/ol-components/src/components/TiptapEditor/ArticleEditor.tsx +++ b/frontends/ol-components/src/components/TiptapEditor/ArticleEditor.tsx @@ -343,7 +343,7 @@ const ArticleEditor = ({ onSave, readOnly, article }: ArticleEditorProps) => { Edit diff --git a/frontends/ol-components/src/components/TiptapEditor/extensions/node/Divider/DividerNode.tsx b/frontends/ol-components/src/components/TiptapEditor/extensions/node/Divider/DividerNode.tsx index d0821ebaea..55e2095e64 100644 --- a/frontends/ol-components/src/components/TiptapEditor/extensions/node/Divider/DividerNode.tsx +++ b/frontends/ol-components/src/components/TiptapEditor/extensions/node/Divider/DividerNode.tsx @@ -55,7 +55,7 @@ export const DividerNode = Node.create({ }, parseHTML() { - return [{ tag: 'div[data-type="divider"]' }, { tag: ". . ." }] + return [{ tag: 'div[data-type="divider"]' }] }, renderHTML({ HTMLAttributes }) { diff --git a/frontends/ol-components/src/components/TiptapEditor/extensions/node/Image/ImageWithCaption.tsx b/frontends/ol-components/src/components/TiptapEditor/extensions/node/Image/ImageWithCaption.tsx index 74f23a6ded..9f4fd1cae4 100644 --- a/frontends/ol-components/src/components/TiptapEditor/extensions/node/Image/ImageWithCaption.tsx +++ b/frontends/ol-components/src/components/TiptapEditor/extensions/node/Image/ImageWithCaption.tsx @@ -1,4 +1,4 @@ -import React from "react" +import React, { useRef, useEffect, useState } from "react" import { NodeViewWrapper } from "@tiptap/react" import type { ReactNodeViewProps } from "@tiptap/react" import styled from "@emotion/styled" @@ -18,7 +18,6 @@ const Container = styled.div({ img: { width: "100%", height: "auto", - aspectRatio: "16/9", borderRadius: "6px", display: "block", }, @@ -69,7 +68,6 @@ const Container = styled.div({ padding: "6px 10px", borderRadius: "8px", gap: "8px", - width: "250px", justifyContent: "center", "&::after": { @@ -116,6 +114,21 @@ const Container = styled.div({ display: "flex", }, }, + + ".svg-icon": { + fill: "white", + }, + ".media-toolbar-wide": { + width: "250px", + }, + ".media-toolbar-default": { + width: "150px", + }, + + ".img-contained": { + width: "auto !important", + margin: "0 auto", + }, }) enum Layout { @@ -145,10 +158,40 @@ export function ImageWithCaption({ node, updateAttributes, }: ReactNodeViewProps) { + const imgRef = useRef(null) + const containerRef = useRef(null) + + const [canExpand, setCanExpand] = useState(true) + const { layout, caption, src, alt } = node.attrs const isEditable = node.attrs.editable + useEffect(() => { + if (!imgRef.current || !containerRef.current) return + + const img = imgRef.current + const container = containerRef.current + + const checkSize = () => { + const containerWidth = container.offsetWidth + const imageNaturalWidth = img.naturalWidth + + // If the image can't expand beyond the container, disable wide/full + setCanExpand(imageNaturalWidth > containerWidth) + } + + // when image loads + if (img.complete) { + checkSize() + } else { + img.onload = checkSize + } + + window.addEventListener("resize", checkSize) + return () => window.removeEventListener("resize", checkSize) + }, [src]) + const openAltTextDialog = async () => { try { const result = await NiceModal.show(ImageAltTextInput, { @@ -161,10 +204,12 @@ export function ImageWithCaption({ } return ( - + {isEditable && ( -
+
- - + {canExpand && ( + <> + {" "} + + + + )}