Skip to content

WIP: Feature: add a qna evaluation for post training checkpoint evals #474

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions src/app/api/native/eval/checkpoints/route.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// app/api/native/eval/checkpoints/route.ts
'use server';

import type { NextRequest } from 'next/server';
import { NextResponse } from 'next/server';

export async function GET(request: NextRequest) {
console.log('Received GET request for /api/native/eval/checkpoints');

try {
const apiUrl = 'http://localhost:8080/checkpoints';
console.log(`Fetching checkpoints from external API: ${apiUrl}`);

const res = await fetch(apiUrl, {
method: 'GET',
headers: {
'Content-Type': 'application/json'
}
});

console.log(`External API response status: ${res.status}`);

if (!res.ok) {
const errorData = await res.json();
console.error('Error from external API:', errorData);
return NextResponse.json({ error: errorData.error || 'Failed to fetch checkpoints' }, { status: res.status });
}

const data = await res.json();
console.log('Checkpoints data fetched successfully:', data);

// Validate that data is an array
if (!Array.isArray(data)) {
console.warn('Unexpected data format from external API:', data);
return NextResponse.json({ error: 'Invalid data format received from checkpoints API.' }, { status: 500 });
}

return NextResponse.json(data);
} catch (error) {
console.error('Error fetching checkpoints:', error);
return NextResponse.json({ error: 'Unable to reach the checkpoints endpoint' }, { status: 500 });
}
}
42 changes: 42 additions & 0 deletions src/app/api/native/eval/qna/route.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// src/app/api/native/eval/qna/route.ts
import { NextResponse, NextRequest } from 'next/server';

const BACKEND_URL = process.env.BACKEND_URL || 'http://localhost:8080';
const HARD_CODED_QNA_PATH = '/var/home/cloud-user/.local/share/instructlab/taxonomy/knowledge/history/amazon/qna.yaml';

export async function POST(request: NextRequest) {
try {
const body = await request.json();
console.log('[SERVER] Body received:', body);

const { selectedModelDir } = body;
if (!selectedModelDir) {
console.error('[SERVER] Missing selectedModelDir in request body!');
return NextResponse.json({ error: 'Missing required field: selectedModelDir' }, { status: 400 });
}

console.log('[SERVER] selectedModelDir:', selectedModelDir);

// Forward to Go backend
const response = await fetch(`${BACKEND_URL}/qna-eval`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
model_path: selectedModelDir,
yaml_file: HARD_CODED_QNA_PATH
})
});

const data = await response.json();
console.log('[SERVER] Response from Go backend:', data);

if (!response.ok) {
return NextResponse.json({ error: data.error || 'Failed to initiate QnA evaluation' }, { status: response.status });
}

return NextResponse.json(data, { status: 200 });
} catch (error) {
console.error('Error in /api/native/eval/qna route:', error);
return NextResponse.json({ error: 'Internal Server Error' }, { status: 500 });
}
}
184 changes: 184 additions & 0 deletions src/components/Dashboard/Native/dashboard.tsx
Original file line number Diff line number Diff line change
@@ -75,6 +75,18 @@ const DashboardNative: React.FunctionComponent = () => {
const [isPublishing, setIsPublishing] = React.useState(false);
const [expandedFiles, setExpandedFiles] = React.useState<Record<string, boolean>>({});

// State Variables for Evaluate Checkpoint
const [isEvalModalOpen, setIsEvalModalOpen] = React.useState<boolean>(false);
const [checkpoints, setCheckpoints] = React.useState<string[]>([]);
const [isCheckpointsLoading, setIsCheckpointsLoading] = React.useState<boolean>(false);
const [checkpointsError, setCheckpointsError] = React.useState<string | null>(null);
const [isDropdownOpen, setIsDropdownOpen] = React.useState<boolean>(false);
const [selectedCheckpoint, setSelectedCheckpoint] = React.useState<string | null>(null);

// QnA eval result
const [qnaEvalResult, setQnaEvalResult] = React.useState<string>('');
const [isQnaLoading, setIsQnaLoading] = React.useState<boolean>(false);

const router = useRouter();

// Fetch branches from the API route
@@ -299,6 +311,100 @@ const DashboardNative: React.FunctionComponent = () => {
}));
};

const handleOpenEvalModal = () => {
setIsEvalModalOpen(true);
fetchCheckpoints();
};

const handleCloseEvalModal = () => {
setIsEvalModalOpen(false);
setCheckpoints([]);
setCheckpointsError(null);
setSelectedCheckpoint(null);
setQnaEvalResult('');
setIsQnaLoading(false);
};

// **New Function to Fetch Checkpoints from API Route**
const fetchCheckpoints = async () => {
setIsCheckpointsLoading(true);
setCheckpointsError(null);
try {
const response = await fetch('/api/native/eval/checkpoints');
console.log('Response status:', response.status);
const data = await response.json();
console.log('Checkpoints data:', data);

if (response.ok) {
// Assuming the API returns an array of checkpoints
if (Array.isArray(data) && data.length > 0) {
setCheckpoints(data);
console.log('Checkpoints set successfully:', data);
} else {
setCheckpoints([]);
console.log('No checkpoints returned from API.');
}
} else {
setCheckpointsError(data.error || 'Failed to fetch checkpoints.');
console.error('Error fetching checkpoints:', data.error || 'Failed to fetch checkpoints.');
}
} catch (error) {
console.error('Error fetching checkpoints:', error);
setCheckpointsError('Unable to reach the checkpoints endpoint.');
} finally {
setIsCheckpointsLoading(false);
}
};

// Checkpoint select dropdown
const onDropdownToggle = (isOpen: boolean) => setIsDropdownOpen(isOpen);
const onSelectCheckpoint = (event: React.MouseEvent<Element, MouseEvent>, selection: string) => {
setSelectedCheckpoint(selection);
setIsDropdownOpen(false);
};

const handleEvaluateQnA = async () => {
if (!selectedCheckpoint) {
addDangerAlert('Please select a checkpoint to evaluate.');
return;
}

setIsQnaLoading(true);
setQnaEvalResult('');

// TODO: dynamically prepend the checkpoint path
const selectedModelDir = '/var/home/cloud-user/.local/share/instructlab/checkpoints/hf_format/' + selectedCheckpoint;

console.log('[CLIENT] Sending to /api/native/eval/qna:', selectedModelDir);

try {
const res = await fetch('/api/native/eval/qna', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ selectedModelDir })
});

const data = await res.json();
console.log('[CLIENT] Response from /api/native/eval/qna:', data);

if (!res.ok) {
addDangerAlert(data.error || 'Failed to evaluate QnA.');
} else {
if (data.result) {
setQnaEvalResult(data.result);
addSuccessAlert('QnA Evaluation succeeded!');
} else {
setQnaEvalResult('Evaluation completed (no result field).');
}
}
} catch (error) {
console.error('Error evaluating QnA:', error);
addDangerAlert('Error evaluating QnA.');
} finally {
setIsQnaLoading(false);
}
};

return (
<div>
<PageBreadcrumb hasBodyWrapper={false}>
@@ -459,6 +565,84 @@ const DashboardNative: React.FunctionComponent = () => {
</PageSection>
)}

{/* Evaluate Checkpoint Modal */}
<Modal
variant={ModalVariant.medium}
title="Evaluate Checkpoint"
isOpen={isEvalModalOpen}
onClose={handleCloseEvalModal}
aria-labelledby="evaluate-checkpoint-modal-title"
aria-describedby="evaluate-checkpoint-modal-body"
>
<ModalHeader title="Evaluate Checkpoint" />
<ModalBody id="evaluate-checkpoint-modal-body">
{isCheckpointsLoading ? (
<Spinner size="lg" aria-label="Loading checkpoints" />
) : checkpointsError ? (
<Alert variant="danger" title={checkpointsError} isInline />
) : (
<>
<div style={{ marginBottom: '1rem' }}>
<label style={{ display: 'block', marginBottom: '0.4rem' }}>Select a Checkpoint:</label>
<Dropdown
isOpen={isDropdownOpen}
onSelect={onSelectCheckpoint}
onOpenChange={onDropdownToggle}
toggle={(toggleRef: React.Ref<MenuToggleElement>) => (
<MenuToggle ref={toggleRef} onClick={() => onDropdownToggle(!isDropdownOpen)} isExpanded={isDropdownOpen}>
{selectedCheckpoint || 'Select a Checkpoint'}
</MenuToggle>
)}
ouiaId="EvaluateCheckpointDropdown"
shouldFocusToggleOnSelect
>
<DropdownList>
{checkpoints.length > 0 ? (
checkpoints.map((checkpoint) => (
<DropdownItem key={checkpoint} value={checkpoint}>
{checkpoint}
</DropdownItem>
))
) : (
<DropdownItem key="no-checkpoints" isDisabled>
No checkpoints available
</DropdownItem>
)}
</DropdownList>
</Dropdown>
</div>

{/* Display the evaluation result */}
{qnaEvalResult && (
<div style={{ marginTop: '1rem' }}>
<b>Evaluation Output:</b>
<pre
style={{
marginTop: '0.5rem',
backgroundColor: '#f5f5f5',
padding: '1rem',
borderRadius: '4px',
maxHeight: '300px',
overflowY: 'auto'
}}
>
{qnaEvalResult}
</pre>
</div>
)}
</>
)}
</ModalBody>
<ModalFooter>
<Button key="evaluateQnA" variant="primary" onClick={handleEvaluateQnA} isDisabled={!selectedCheckpoint || isQnaLoading}>
{isQnaLoading ? 'Evaluating...' : 'Evaluate'}
</Button>
<Button key="cancel" variant="secondary" onClick={handleCloseEvalModal}>
Cancel
</Button>
</ModalFooter>
</Modal>

<Modal
variant={ModalVariant.medium}
title={`Files Contained in Branch: ${diffData?.branch}`}