Skip to content

Commit

Permalink
Merge pull request #493 from stacklok/extract-input-code-snippets
Browse files Browse the repository at this point in the history
Extract and process code snippets in the user query
  • Loading branch information
ptelang authored Jan 6, 2025
2 parents 1bb78e2 + b9d8fce commit 8328979
Showing 1 changed file with 32 additions and 10 deletions.
42 changes: 32 additions & 10 deletions src/codegate/pipeline/codegate_context_retriever/codegate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import re

import structlog
from litellm import ChatCompletionRequest
Expand All @@ -9,7 +10,9 @@
PipelineResult,
PipelineStep,
)
from codegate.pipeline.extract_snippets.extract_snippets import extract_snippets
from codegate.storage.storage_engine import StorageEngine
from codegate.utils.package_extractor import PackageExtractor
from codegate.utils.utils import generate_vector_string

logger = structlog.get_logger("codegate")
Expand Down Expand Up @@ -64,26 +67,45 @@ async def process(
if len(user_messages) == 0:
return PipelineResult(request=request)

context_str = "CodeGate did not find any malicious or archived packages."
# Create storage engine object
storage_engine = StorageEngine()

# Extract any code snippets
snippets = extract_snippets(user_messages)

# Collect all packages referenced in the snippets
snippet_packages = []
for snippet in snippets:
snippet_packages.extend(
PackageExtractor.extract_packages(snippet.code, snippet.language)
)
logger.info(f"Found {len(snippet_packages)} packages in code snippets.")

# Find bad packages in the snippets
bad_snippet_packages = await storage_engine.search_by_property("name", snippet_packages)
logger.info(f"Found {len(bad_snippet_packages)} bad packages in code snippets.")

# Remove code snippets from the user messages and search for bad packages
# in the rest of the user query/messsages
user_messages = re.sub(r"```.*?```", "", user_messages, flags=re.DOTALL)

# Vector search to find bad packages
storage_engine = StorageEngine()
searched_objects = await storage_engine.search(query=user_messages, distance=0.8, limit=100)
bad_packages = await storage_engine.search(query=user_messages, distance=0.8, limit=100)

logger.info(
f"Found {len(searched_objects)} matches in the database",
searched_objects=searched_objects,
)
# All bad packages
all_bad_packages = bad_snippet_packages + bad_packages

logger.info(f"Adding {len(all_bad_packages)} bad packages to the context.")

# Generate context string using the searched objects
logger.info(f"Adding {len(searched_objects)} packages to the context")
context_str = "CodeGate did not find any malicious or archived packages."

# Nothing to do if no bad packages are found
if len(searched_objects) == 0:
if len(all_bad_packages) == 0:
return PipelineResult(request=request, context=context)
else:
# Add context for bad packages
context_str = self.generate_context_str(searched_objects, context)
context_str = self.generate_context_str(all_bad_packages, context)
context.bad_packages_found = True

last_user_idx = self.get_last_user_message_idx(request)
Expand Down

0 comments on commit 8328979

Please sign in to comment.