Skip to content

Commit

Permalink
Merge pull request #547 from jhrozek/unredact_bug
Browse files Browse the repository at this point in the history
Fix pipeline handling in the copilot provider
  • Loading branch information
lukehinds authored Jan 13, 2025
2 parents d7be333 + cb9540a commit c68ef71
Show file tree
Hide file tree
Showing 16 changed files with 156 additions and 160 deletions.
26 changes: 15 additions & 11 deletions src/codegate/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,13 @@ def __init__(
self.secret_manager = secret_manager
self.is_fim = is_fim
self.context = PipelineContext()

# we create the sensitive context here so that it is not shared between individual requests
# TODO: could we get away with just generating the session ID for an instance?
self.context.sensitive = PipelineSensitiveData(
manager=self.secret_manager,
session_id=str(uuid.uuid4()),
)
self.context.metadata["is_fim"] = is_fim

async def process_request(
Expand All @@ -290,17 +297,14 @@ async def process_request(
is_copilot: bool = False,
) -> PipelineResult:
"""Process a request through all pipeline steps"""
self.context.sensitive = PipelineSensitiveData(
manager=self.secret_manager,
session_id=str(uuid.uuid4()),
api_key=api_key,
model=model,
provider=provider,
api_base=api_base,
)
self.context.metadata["extra_headers"] = extra_headers
current_request = request

self.context.sensitive.api_key = api_key
self.context.sensitive.model = model
self.context.sensitive.provider = provider
self.context.sensitive.api_base = api_base

# For Copilot provider=openai. Use a flag to not clash with other places that may use that.
provider_db = "copilot" if is_copilot else provider

Expand Down Expand Up @@ -336,8 +340,9 @@ def __init__(
self.pipeline_steps = pipeline_steps
self.secret_manager = secret_manager
self.is_fim = is_fim
self.instance = self._create_instance()

def create_instance(self) -> InputPipelineInstance:
def _create_instance(self) -> InputPipelineInstance:
"""Create a new pipeline instance for processing a request"""
return InputPipelineInstance(self.pipeline_steps, self.secret_manager, self.is_fim)

Expand All @@ -352,7 +357,6 @@ async def process_request(
is_copilot: bool = False,
) -> PipelineResult:
"""Create a new pipeline instance and process the request"""
instance = self.create_instance()
return await instance.process_request(
return await self.instance.process_request(
request, provider, model, api_key, api_base, extra_headers, is_copilot
)
4 changes: 4 additions & 0 deletions src/codegate/pipeline/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ async def process_stream(
logger.error(f"Error processing stream: {e}")
raise e
finally:
# Don't flush the buffer if we assume we'll call the pipeline again
if cleanup_sensitive is False:
return

# Process any remaining content in buffer when stream ends
if self._context.buffer:
final_content = "".join(self._context.buffer)
Expand Down
21 changes: 13 additions & 8 deletions src/codegate/pipeline/secrets/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class SecretsManager:

def __init__(self):
self.crypto = CodeGateCrypto()
self._session_store: dict[str, SecretEntry] = {}
self._session_store: dict[str, dict[str, SecretEntry]] = {}
self._encrypted_to_session: dict[str, str] = {} # Reverse lookup index

def store_secret(self, value: str, service: str, secret_type: str, session_id: str) -> str:
Expand All @@ -41,12 +41,14 @@ def store_secret(self, value: str, service: str, secret_type: str, session_id: s
encrypted_value = self.crypto.encrypt_token(value, session_id)

# Store mappings
self._session_store[session_id] = SecretEntry(
session_secrets = self._session_store.get(session_id, {})
session_secrets[encrypted_value] = SecretEntry(
original=value,
encrypted=encrypted_value,
service=service,
secret_type=secret_type,
)
self._session_store[session_id] = session_secrets
self._encrypted_to_session[encrypted_value] = session_id

logger.debug("Stored secret", service=service, type=secret_type, encrypted=encrypted_value)
Expand All @@ -58,7 +60,9 @@ def get_original_value(self, encrypted_value: str, session_id: str) -> Optional[
try:
stored_session_id = self._encrypted_to_session.get(encrypted_value)
if stored_session_id == session_id:
return self._session_store[session_id].original
session_secrets = self._session_store[session_id].get(encrypted_value)
if session_secrets:
return session_secrets.original
except Exception as e:
logger.error("Error retrieving secret", error=str(e))
return None
Expand All @@ -71,9 +75,10 @@ def cleanup(self):
"""Securely wipe sensitive data"""
try:
# Convert and wipe original values
for entry in self._session_store.values():
original_bytes = bytearray(entry.original.encode())
self.crypto.wipe_bytearray(original_bytes)
for secrets in self._session_store.values():
for entry in secrets.values():
original_bytes = bytearray(entry.original.encode())
self.crypto.wipe_bytearray(original_bytes)

# Clear the dictionaries
self._session_store.clear()
Expand All @@ -92,9 +97,9 @@ def cleanup_session(self, session_id: str):
"""
try:
# Get the secret entry for the session
entry = self._session_store.get(session_id)
secrets = self._session_store.get(session_id, {})

if entry:
for entry in secrets.values():
# Securely wipe the original value
original_bytes = bytearray(entry.original.encode())
self.crypto.wipe_bytearray(original_bytes)
Expand Down
3 changes: 3 additions & 0 deletions src/codegate/pipeline/secrets/secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ async def process_chunk(
if match:
# Found a complete marker, process it
encrypted_value = match.group(1)
print("----> encrypted_value: ", encrypted_value)
original_value = input_context.sensitive.manager.get_original_value(
encrypted_value,
input_context.sensitive.session_id,
Expand All @@ -370,6 +371,8 @@ async def process_chunk(
if original_value is None:
# If value not found, leave as is
original_value = match.group(0) # Keep the REDACTED marker
else:
print("----> original_value: ", original_value)

# Post an alert with the redacted content
input_context.add_alert(self.name, trigger_string=encrypted_value)
Expand Down
14 changes: 3 additions & 11 deletions src/codegate/providers/anthropic/provider.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import json
from typing import Optional

import structlog
from fastapi import Header, HTTPException, Request

from codegate.pipeline.base import SequentialPipelineProcessor
from codegate.pipeline.output import OutputPipelineProcessor
from codegate.pipeline.factory import PipelineFactory
from codegate.providers.anthropic.adapter import AnthropicInputNormalizer, AnthropicOutputNormalizer
from codegate.providers.anthropic.completion_handler import AnthropicCompletion
from codegate.providers.base import BaseProvider
Expand All @@ -15,20 +13,14 @@
class AnthropicProvider(BaseProvider):
def __init__(
self,
pipeline_processor: Optional[SequentialPipelineProcessor] = None,
fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None,
output_pipeline_processor: Optional[OutputPipelineProcessor] = None,
fim_output_pipeline_processor: Optional[OutputPipelineProcessor] = None,
pipeline_factory: PipelineFactory,
):
completion_handler = AnthropicCompletion(stream_generator=anthropic_stream_generator)
super().__init__(
AnthropicInputNormalizer(),
AnthropicOutputNormalizer(),
completion_handler,
pipeline_processor,
fim_pipeline_processor,
output_pipeline_processor,
fim_output_pipeline_processor,
pipeline_factory,
)

@property
Expand Down
22 changes: 8 additions & 14 deletions src/codegate/providers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from codegate.pipeline.base import (
PipelineContext,
PipelineResult,
SequentialPipelineProcessor,
)
from codegate.pipeline.output import OutputPipelineInstance, OutputPipelineProcessor
from codegate.pipeline.factory import PipelineFactory
from codegate.pipeline.output import OutputPipelineInstance
from codegate.providers.completion.base import BaseCompletionHandler
from codegate.providers.formatting.input_pipeline import PipelineResponseFormatter
from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer
Expand All @@ -34,19 +34,13 @@ def __init__(
input_normalizer: ModelInputNormalizer,
output_normalizer: ModelOutputNormalizer,
completion_handler: BaseCompletionHandler,
pipeline_processor: Optional[SequentialPipelineProcessor] = None,
fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None,
output_pipeline_processor: Optional[OutputPipelineProcessor] = None,
fim_output_pipeline_processor: Optional[OutputPipelineProcessor] = None,
pipeline_factory: PipelineFactory,
):
self.router = APIRouter()
self._completion_handler = completion_handler
self._input_normalizer = input_normalizer
self._output_normalizer = output_normalizer
self._pipeline_processor = pipeline_processor
self._fim_pipelin_processor = fim_pipeline_processor
self._output_pipeline_processor = output_pipeline_processor
self._fim_output_pipeline_processor = fim_output_pipeline_processor
self._pipeline_factory = pipeline_factory
self._db_recorder = DbRecorder()
self._pipeline_response_formatter = PipelineResponseFormatter(
output_normalizer, self._db_recorder
Expand All @@ -73,10 +67,10 @@ async def _run_output_stream_pipeline(
# Decide which pipeline processor to use
out_pipeline_processor = None
if is_fim_request:
out_pipeline_processor = self._fim_output_pipeline_processor
out_pipeline_processor = self._pipeline_factory.create_fim_output_pipeline()
logger.info("FIM pipeline selected for output.")
else:
out_pipeline_processor = self._output_pipeline_processor
out_pipeline_processor = self._pipeline_factory.create_output_pipeline()
logger.info("Chat completion pipeline selected for output.")
if out_pipeline_processor is None:
logger.info("No output pipeline processor found, passing through")
Expand Down Expand Up @@ -117,11 +111,11 @@ async def _run_input_pipeline(
) -> PipelineResult:
# Decide which pipeline processor to use
if is_fim_request:
pipeline_processor = self._fim_pipelin_processor
pipeline_processor = self._pipeline_factory.create_fim_pipeline()
logger.info("FIM pipeline selected for execution.")
normalized_request = self._fim_normalizer.normalize(normalized_request)
else:
pipeline_processor = self._pipeline_processor
pipeline_processor = self._pipeline_factory.create_input_pipeline()
logger.info("Chat completion pipeline selected for execution.")
if pipeline_processor is None:
return PipelineResult(request=normalized_request)
Expand Down
22 changes: 16 additions & 6 deletions src/codegate/providers/copilot/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class CopilotPipeline(ABC):

def __init__(self, pipeline_factory: PipelineFactory):
self.pipeline_factory = pipeline_factory
self.instance = self._create_pipeline()
self.normalizer = self._create_normalizer()
self.provider_name = "openai"

Expand All @@ -33,7 +34,7 @@ def _create_normalizer(self):
pass

@abstractmethod
def create_pipeline(self) -> SequentialPipelineProcessor:
def _create_pipeline(self) -> SequentialPipelineProcessor:
"""Each strategy defines which pipeline to create"""
pass

Expand Down Expand Up @@ -84,7 +85,11 @@ def _create_shortcut_response(result: PipelineResult, model: str) -> bytes:
body = response.model_dump_json(exclude_none=True, exclude_unset=True).encode()
return body

async def process_body(self, headers: list[str], body: bytes) -> Tuple[bytes, PipelineContext]:
async def process_body(
self,
headers: list[str],
body: bytes,
) -> Tuple[bytes, PipelineContext | None]:
"""Common processing logic for all strategies"""
try:
normalized_body = self.normalizer.normalize(body)
Expand All @@ -97,8 +102,7 @@ async def process_body(self, headers: list[str], body: bytes) -> Tuple[bytes, Pi
except ValueError:
continue

pipeline = self.create_pipeline()
result = await pipeline.process_request(
result = await self.instance.process_request(
request=normalized_body,
provider=self.provider_name,
model=normalized_body.get("model", "gpt-4o-mini"),
Expand Down Expand Up @@ -168,10 +172,13 @@ class CopilotFimPipeline(CopilotPipeline):
format and the FIM pipeline used by all providers.
"""

def __init__(self, pipeline_factory: PipelineFactory):
super().__init__(pipeline_factory)

def _create_normalizer(self):
return CopilotFimNormalizer()

def create_pipeline(self) -> SequentialPipelineProcessor:
def _create_pipeline(self) -> SequentialPipelineProcessor:
return self.pipeline_factory.create_fim_pipeline()


Expand All @@ -181,8 +188,11 @@ class CopilotChatPipeline(CopilotPipeline):
format and the chat (input) pipeline used by all providers.
"""

def __init__(self, pipeline_factory: PipelineFactory):
super().__init__(pipeline_factory)

def _create_normalizer(self):
return CopilotChatNormalizer()

def create_pipeline(self) -> SequentialPipelineProcessor:
def _create_pipeline(self) -> SequentialPipelineProcessor:
return self.pipeline_factory.create_input_pipeline()
Loading

0 comments on commit c68ef71

Please sign in to comment.