diff --git a/libs/community/langchain_community/document_loaders/confluence.py b/libs/community/langchain_community/document_loaders/confluence.py
index ec1de22..662cd97 100644
--- a/libs/community/langchain_community/document_loaders/confluence.py
+++ b/libs/community/langchain_community/document_loaders/confluence.py
@@ -12,11 +12,112 @@
     wait_exponential,
 )
 
-from langchain_community.document_loaders.base import BaseLoader
+from langchain_community.document_loaders.base import BaseBlobParser, BaseLoader
+from langchain_community.document_loaders.blob_loaders import Blob
 
 logger = logging.getLogger(__name__)
 
 
+class SVGParser(BaseBlobParser):
+    """Parse SVG blobs by rasterizing them to PNG and running Tesseract OCR."""
+
+    def lazy_parse(self, blob: Blob) -> Iterator[Document]:
+        """Yield one Document holding the text recognized in the SVG."""
+        try:
+            import pytesseract
+            from PIL import Image
+            from reportlab.graphics import renderPM
+            from svglib.svglib import svg2rlg
+        except ImportError:
+            raise ImportError(
+                "`pytesseract`, `Pillow`, `reportlab` or `svglib` package not found, "
+                "please run `pip install pytesseract Pillow reportlab svglib`"
+            )
+        # svg2rlg takes a path or a file-like object, so wrap the raw bytes.
+        drawing = svg2rlg(BytesIO(blob.as_bytes()))
+        img_data = BytesIO()
+        renderPM.drawToFile(drawing, img_data, fmt="PNG")
+        img_data.seek(0)
+        image = Image.open(img_data)
+        text = pytesseract.image_to_string(image)
+        yield Document(page_content=text, metadata={"source": blob.source})
+
+
+class XLSParser(BaseBlobParser):
+    """Parse XLS workbooks into tab-separated text using xlrd."""
+
+    def lazy_parse(self, blob: Blob) -> Iterator[Document]:
+        """Yield one Document with every sheet rendered as tab-separated rows."""
+        try:
+            import xlrd
+        except ImportError:
+            raise ImportError("`xlrd` package not found, please run `pip install xlrd`")
+
+        workbook = xlrd.open_workbook(file_contents=blob.as_bytes())
+        parts = []
+        for sheet in workbook.sheets():
+            parts.append(f"{sheet.name}:\n")
+            for row in range(sheet.nrows):
+                # Every cell, the last one included, is followed by a tab.
+                parts.extend(
+                    f"{sheet.cell_value(row, col)}\t" for col in range(sheet.ncols)
+                )
+                parts.append("\n")
+            parts.append("\n")
+
+        yield Document(page_content="".join(parts), metadata={"source": blob.source})
+
+
+class Doc2TXTParser(BaseBlobParser):
+    """Parse DOCX blobs into plain text using docx2txt."""
+
+    def lazy_parse(self, blob: Blob) -> Iterator[Document]:
+        """Yield one Document holding the text extracted from the DOCX file."""
+        try:
+            import docx2txt
+        except ImportError:
+            raise ImportError(
+                "`docx2txt` package not found, please run `pip install docx2txt`"
+            )
+        yield Document(
+            page_content=docx2txt.process(BytesIO(blob.as_bytes())),
+            metadata={"source": blob.source},
+        )
+
+
+def default_parser_factory(attachment_info: dict) -> Optional[BaseBlobParser]:
+    """Default parser factory for ConfluenceLoader.
+
+    This function takes the attachment information from Confluence and returns
+    a parser for the attachment.
+
+    :param attachment_info: Attachment dict; ``metadata.mediaType`` picks the parser.
+    :return: The matching parser, or ``None`` for an unsupported media type.
+    """
+    mime_type = attachment_info["metadata"]["mediaType"]
+    if mime_type == "application/pdf":
+        from langchain_community.document_loaders.parsers.pdf import PyMuPDFParser
+
+        return PyMuPDFParser()
+    elif (
+        mime_type == "application/vnd.openxmlformats-officedocument"
+        ".wordprocessingml.document"
+    ):
+        return Doc2TXTParser()
+    elif mime_type in ("image/png", "image/jpg", "image/jpeg"):
+        from langchain_community.document_loaders.parsers.images import (
+            TesseractBlobParser,
+        )
+
+        return TesseractBlobParser()
+    elif mime_type == "application/vnd.ms-excel":
+        return XLSParser()
+    elif mime_type == "image/svg+xml":
+        return SVGParser()
+
+    return None
+
+
 class ContentFormat(str, Enum):
     """Enumerator of the content formats of Confluence page."""
 
@@ -123,6 +224,11 @@ class ConfluenceLoader(BaseLoader):
     :param attachment_filter_func: A function that takes the attachment information
                                    from Confluence and decides whether or not the
                                    attachment is processed.
+    :type attachment_filter_func: Callable[[dict], bool], optional
+    :param attachment_parser_factory: A function that takes the attachment information
+                                      from Confluence and returns a parser for the
+                                      attachment.
+ :type attachment_parser_factory: Callable[[dict], Optional[BaseBlobParser]], optional :param include_comments: defaults to False :type include_comments: bool, optional :param content_format: Specify content format, defaults to @@ -180,6 +286,9 @@ def __init__( keep_markdown_format: bool = False, keep_newlines: bool = False, attachment_filter_func: Optional[Callable[[dict], bool]] = None, + attachment_parser_factory: Optional[ + Callable[[dict], Optional[BaseBlobParser]] + ] = default_parser_factory, ): self.space_key = space_key self.page_ids = page_ids @@ -197,6 +306,7 @@ def __init__( self.keep_markdown_format = keep_markdown_format self.keep_newlines = keep_newlines self.attachment_filter_func = attachment_filter_func + self.attachment_parser_factory = attachment_parser_factory confluence_kwargs = confluence_kwargs or {} errors = ConfluenceLoader.validate_init_args( @@ -675,26 +785,32 @@ def process_attachment( absolute_url = self.base_url + attachment["_links"]["download"] title = attachment["title"] try: - if media_type == "application/pdf": - text = title + self.process_pdf(absolute_url, ocr_languages) - elif ( - media_type == "image/png" - or media_type == "image/jpg" - or media_type == "image/jpeg" - ): - text = title + self.process_image(absolute_url, ocr_languages) - elif ( - media_type == "application/vnd.openxmlformats-officedocument" - ".wordprocessingml.document" - ): - text = title + self.process_doc(absolute_url) - elif media_type == "application/vnd.ms-excel": - text = title + self.process_xls(absolute_url) - elif media_type == "image/svg+xml": - text = title + self.process_svg(absolute_url, ocr_languages) - else: - continue - texts.append(text) + if self.attachment_parser_factory: + parser = self.attachment_parser_factory(attachment) + if parser is None: + continue + + response = self.confluence.request(path=absolute_url, absolute=True) + + if ( + response.status_code != 200 + or response.content == b"" + or response.content is None + ): + 
continue + + blob = Blob( + data=response.content, + mimetype=media_type, + ) + text = ( + title + + " " + + "\n\n".join( + [doc.page_content for doc in parser.lazy_parse(blob)] + ) + ) + texts.append(text) except requests.HTTPError as e: if e.response.status_code == 404: print(f"Attachment not found at {absolute_url}") # noqa: T201 @@ -703,177 +819,3 @@ def process_attachment( raise return texts - - def process_pdf( - self, - link: str, - ocr_languages: Optional[str] = None, - ) -> str: - try: - import pytesseract - from pdf2image import convert_from_bytes - except ImportError: - raise ImportError( - "`pytesseract` or `pdf2image` package not found, " - "please run `pip install pytesseract pdf2image`" - ) - - response = self.confluence.request(path=link, absolute=True) - text = "" - - if ( - response.status_code != 200 - or response.content == b"" - or response.content is None - ): - return text - try: - images = convert_from_bytes(response.content) - except ValueError: - return text - - for i, image in enumerate(images): - try: - image_text = pytesseract.image_to_string(image, lang=ocr_languages) - text += f"Page {i + 1}:\n{image_text}\n\n" - except pytesseract.TesseractError as ex: - logger.warning(f"TesseractError: {ex}") - - return text - - def process_image( - self, - link: str, - ocr_languages: Optional[str] = None, - ) -> str: - try: - import pytesseract - from PIL import Image - except ImportError: - raise ImportError( - "`pytesseract` or `Pillow` package not found, " - "please run `pip install pytesseract Pillow`" - ) - - response = self.confluence.request(path=link, absolute=True) - text = "" - - if ( - response.status_code != 200 - or response.content == b"" - or response.content is None - ): - return text - try: - image = Image.open(BytesIO(response.content)) - except OSError: - return text - - return pytesseract.image_to_string(image, lang=ocr_languages) - - def process_doc(self, link: str) -> str: - try: - import docx2txt - except ImportError: - raise 
ImportError( - "`docx2txt` package not found, please run `pip install docx2txt`" - ) - - response = self.confluence.request(path=link, absolute=True) - text = "" - - if ( - response.status_code != 200 - or response.content == b"" - or response.content is None - ): - return text - file_data = BytesIO(response.content) - - return docx2txt.process(file_data) - - def process_xls(self, link: str) -> str: - import io - import os - - try: - import xlrd - - except ImportError: - raise ImportError("`xlrd` package not found, please run `pip install xlrd`") - - try: - import pandas as pd - - except ImportError: - raise ImportError( - "`pandas` package not found, please run `pip install pandas`" - ) - - response = self.confluence.request(path=link, absolute=True) - text = "" - - if ( - response.status_code != 200 - or response.content == b"" - or response.content is None - ): - return text - - filename = os.path.basename(link) - # Getting the whole content of the url after filename, - # Example: ".csv?version=2&modificationDate=1631800010678&cacheVersion=1&api=v2" - file_extension = os.path.splitext(filename)[1] - - if file_extension.startswith( - ".csv" - ): # if the extension found in the url is ".csv" - content_string = response.content.decode("utf-8") - df = pd.read_csv(io.StringIO(content_string)) - text += df.to_string(index=False, header=False) + "\n\n" - else: - workbook = xlrd.open_workbook(file_contents=response.content) - for sheet in workbook.sheets(): - text += f"{sheet.name}:\n" - for row in range(sheet.nrows): - for col in range(sheet.ncols): - text += f"{sheet.cell_value(row, col)}\t" - text += "\n" - text += "\n" - - return text - - def process_svg( - self, - link: str, - ocr_languages: Optional[str] = None, - ) -> str: - try: - import pytesseract - from PIL import Image - from reportlab.graphics import renderPM - from svglib.svglib import svg2rlg - except ImportError: - raise ImportError( - "`pytesseract`, `Pillow`, `reportlab` or `svglib` package not found, 
" - "please run `pip install pytesseract Pillow reportlab svglib`" - ) - - response = self.confluence.request(path=link, absolute=True) - text = "" - - if ( - response.status_code != 200 - or response.content == b"" - or response.content is None - ): - return text - - drawing = svg2rlg(BytesIO(response.content)) - - img_data = BytesIO() - renderPM.drawToFile(drawing, img_data, fmt="PNG") - img_data.seek(0) - image = Image.open(img_data) - - return pytesseract.image_to_string(image, lang=ocr_languages)