Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into sshin/3227
Browse files Browse the repository at this point in the history
  • Loading branch information
suejung-sentry committed Jan 30, 2025
2 parents d163c61 + ae83c2f commit 4978566
Show file tree
Hide file tree
Showing 24 changed files with 584 additions and 55 deletions.
Empty file added api/gen_ai/__init__.py
Empty file.
5 changes: 5 additions & 0 deletions api/gen_ai/serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from rest_framework import serializers


class GenAIAuthSerializer(serializers.Serializer):
is_valid = serializers.BooleanField()
114 changes: 114 additions & 0 deletions api/gen_ai/tests/test_gen_ai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import hmac
from hashlib import sha256
from unittest.mock import patch

from django.urls import reverse
from rest_framework import status
from rest_framework.test import APITestCase
from shared.django_apps.core.tests.factories import OwnerFactory

from codecov_auth.models import GithubAppInstallation

PAYLOAD_SECRET = b"testixik8qdauiab1yiffydimvi72ekq"
VIEW_URL = reverse("auth")


def sign_payload(data: bytes, secret=PAYLOAD_SECRET):
signature = "sha256=" + hmac.new(secret, data, digestmod=sha256).hexdigest()
return signature, data


class GenAIAuthViewTests(APITestCase):
@patch("api.gen_ai.views.get_config", return_value=PAYLOAD_SECRET)
def test_missing_parameters(self, mock_config):
payload = b"{}"
sig, data = sign_payload(payload)
response = self.client.post(
VIEW_URL,
data=data,
content_type="application/json",
HTTP_HTTP_X_GEN_AI_AUTH_SIGNATURE=sig,
)
self.assertEqual(response.status_code, 400)
self.assertIn("Missing required parameters", response.data)

@patch("api.gen_ai.views.get_config", return_value=PAYLOAD_SECRET)
def test_invalid_signature(self, mock_config):
# Correct payload
payload = b'{"external_owner_id":"owner1","repo_service_id":"101"}'
# Wrong signature based on a different payload
wrong_sig = "sha256=" + hmac.new(PAYLOAD_SECRET, b"{}", sha256).hexdigest()
response = self.client.post(
VIEW_URL,
data=payload,
content_type="application/json",
HTTP_HTTP_X_GEN_AI_AUTH_SIGNATURE=wrong_sig,
)
self.assertEqual(response.status_code, 403)

@patch("api.gen_ai.views.get_config", return_value=PAYLOAD_SECRET)
def test_owner_not_found(self, mock_config):
payload = b'{"external_owner_id":"nonexistent_owner","repo_service_id":"101"}'
sig, data = sign_payload(payload)
response = self.client.post(
VIEW_URL,
data=data,
content_type="application/json",
HTTP_HTTP_X_GEN_AI_AUTH_SIGNATURE=sig,
)
self.assertEqual(response.status_code, 404)

@patch("api.gen_ai.views.get_config", return_value=PAYLOAD_SECRET)
def test_no_installation(self, mock_config):
# Create a valid owner but no installation
OwnerFactory(service="github", service_id="owner1", username="test1")
payload = b'{"external_owner_id":"owner1","repo_service_id":"101"}'
sig, data = sign_payload(payload)
response = self.client.post(
VIEW_URL,
data=data,
content_type="application/json",
HTTP_HTTP_X_GEN_AI_AUTH_SIGNATURE=sig,
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data, {"is_valid": False})

@patch("api.gen_ai.views.get_config", return_value=PAYLOAD_SECRET)
def test_authorized(self, mock_config):
owner = OwnerFactory(service="github", service_id="owner2", username="test2")
GithubAppInstallation.objects.create(
installation_id=12345,
owner=owner,
name="ai-features",
repository_service_ids=["101", "202"],
)
payload = b'{"external_owner_id":"owner2","repo_service_id":"101"}'
sig, data = sign_payload(payload)
response = self.client.post(
VIEW_URL,
data=data,
content_type="application/json",
HTTP_HTTP_X_GEN_AI_AUTH_SIGNATURE=sig,
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, {"is_valid": True})

@patch("api.gen_ai.views.get_config", return_value=PAYLOAD_SECRET)
def test_unauthorized(self, mock_config):
owner = OwnerFactory(service="github", service_id="owner3", username="test3")
GithubAppInstallation.objects.create(
installation_id=2,
owner=owner,
name="ai-features",
repository_service_ids=["303", "404"],
)
payload = b'{"external_owner_id":"owner3","repo_service_id":"101"}'
sig, data = sign_payload(payload)
response = self.client.post(
VIEW_URL,
data=data,
content_type="application/json",
HTTP_HTTP_X_GEN_AI_AUTH_SIGNATURE=sig,
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data, {"is_valid": False})
7 changes: 7 additions & 0 deletions api/gen_ai/urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from django.urls import path

from .views import GenAIAuthView

urlpatterns = [
path("auth/", GenAIAuthView.as_view(), name="auth"),
]
61 changes: 61 additions & 0 deletions api/gen_ai/views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import hmac
import logging
from hashlib import sha256

from rest_framework.exceptions import NotFound, PermissionDenied
from rest_framework.permissions import AllowAny
from rest_framework.response import Response
from rest_framework.views import APIView

from api.gen_ai.serializers import GenAIAuthSerializer
from codecov_auth.models import GithubAppInstallation, Owner
from graphql_api.types.owner.owner import AI_FEATURES_GH_APP_ID
from utils.config import get_config

log = logging.getLogger(__name__)


class GenAIAuthView(APIView):
permission_classes = [AllowAny]
serializer_class = GenAIAuthSerializer

def validate_signature(self, request):
key = get_config("gen_ai", "auth_secret")
if not key:
raise PermissionDenied("Invalid signature")

if isinstance(key, str):
key = key.encode("utf-8")
expected_sig = request.headers.get("HTTP-X-GEN-AI-AUTH-SIGNATURE")
computed_sig = (
"sha256=" + hmac.new(key, request.body, digestmod=sha256).hexdigest()
)
if not hmac.compare_digest(computed_sig, expected_sig):
raise PermissionDenied("Invalid signature")

def post(self, request, *args, **kwargs):
self.validate_signature(request)
external_owner_id = request.data.get("external_owner_id")
repo_service_id = request.data.get("repo_service_id")
if not external_owner_id or not repo_service_id:
return Response("Missing required parameters", status=400)
try:
owner = Owner.objects.get(service_id=external_owner_id)
except Owner.DoesNotExist:
raise NotFound("Owner not found")

is_authorized = True

app_install = GithubAppInstallation.objects.filter(
owner_id=owner.ownerid, app_id=AI_FEATURES_GH_APP_ID
).first()

if not app_install:
is_authorized = False

else:
repo_ids = app_install.repository_service_ids
if repo_ids and repo_service_id not in repo_ids:
is_authorized = False

return Response({"is_valid": is_authorized})
1 change: 1 addition & 0 deletions codecov/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,5 @@
# /monitoring/metrics will be a public route unless you take steps at a
# higher level to null-route or redirect it.
path("monitoring/", include("django_prometheus.urls")),
path("gen_ai/", include("api.gen_ai.urls")),
]
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,5 @@ def execute(self, owner, current_owner):
owner.add_admin(current_owner)
return isAdmin or (current_owner.ownerid in admins)
except Exception as error:
print("Error Calling Admin Provider " + repr(error))
print("Error Calling Admin Provider " + repr(error)) # noqa: T201
return False
50 changes: 40 additions & 10 deletions codecov_auth/commands/owner/interactors/save_terms_agreement.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
@dataclass
class TermsAgreementInput:
business_email: Optional[str] = None
name: Optional[str] = None
terms_agreement: bool = False
marketing_consent: bool = False
customer_intent: Optional[str] = None
Expand All @@ -20,7 +21,7 @@ class TermsAgreementInput:
class SaveTermsAgreementInteractor(BaseInteractor):
requires_service = False

def validate(self, input: TermsAgreementInput) -> None:
def validate_deprecated(self, input: TermsAgreementInput) -> None:
valid_customer_intents = ["Business", "BUSINESS", "Personal", "PERSONAL"]
if (
input.customer_intent
Expand All @@ -30,7 +31,11 @@ def validate(self, input: TermsAgreementInput) -> None:
if not self.current_user.is_authenticated:
raise Unauthenticated()

def update_terms_agreement(self, input: TermsAgreementInput) -> None:
def validate(self, input: TermsAgreementInput) -> None:
if not self.current_user.is_authenticated:
raise Unauthenticated()

def update_terms_agreement_deprecated(self, input: TermsAgreementInput) -> None:
self.current_user.terms_agreement = input.terms_agreement
self.current_user.terms_agreement_at = timezone.now()
self.current_user.customer_intent = input.customer_intent
Expand All @@ -44,6 +49,20 @@ def update_terms_agreement(self, input: TermsAgreementInput) -> None:
if input.marketing_consent:
self.send_data_to_marketo()

def update_terms_agreement(self, input: TermsAgreementInput) -> None:
self.current_user.terms_agreement = input.terms_agreement
self.current_user.terms_agreement_at = timezone.now()
self.current_user.name = input.name
self.current_user.email_opt_in = input.marketing_consent
self.current_user.save()

if input.business_email and input.business_email != "":
self.current_user.email = input.business_email
self.current_user.save()

if input.marketing_consent:
self.send_data_to_marketo()

def send_data_to_marketo(self) -> None:
event_data = {
"email": self.current_user.email,
Expand All @@ -52,11 +71,22 @@ def send_data_to_marketo(self) -> None:

@sync_to_async
def execute(self, input: Any) -> None:
typed_input = TermsAgreementInput(
business_email=input.get("business_email"),
terms_agreement=input.get("terms_agreement"),
marketing_consent=input.get("marketing_consent"),
customer_intent=input.get("customer_intent"),
)
self.validate(typed_input)
return self.update_terms_agreement(typed_input)
if input.get("name"):
typed_input = TermsAgreementInput(
business_email=input.get("business_email"),
terms_agreement=input.get("terms_agreement"),
marketing_consent=input.get("marketing_consent"),
name=input.get("name"),
)
self.validate(typed_input)
self.update_terms_agreement(typed_input)
# this handles the deprecated inputs
else:
typed_input = TermsAgreementInput(
business_email=input.get("business_email"),
terms_agreement=input.get("terms_agreement"),
marketing_consent=input.get("marketing_consent"),
customer_intent=input.get("customer_intent"),
)
self.validate_deprecated(typed_input)
self.update_terms_agreement_deprecated(typed_input)
Loading

0 comments on commit 4978566

Please sign in to comment.