|
| 1 | +import { useState } from 'react'; |
| 2 | + |
| 3 | +import { $api, fetchClient } from '@geti-inspect/api'; |
| 4 | +import { useProjectIdentifier } from '@geti-inspect/hooks'; |
| 5 | +import { |
| 6 | + Button, |
| 7 | + ButtonGroup, |
| 8 | + Content, |
| 9 | + Dialog, |
| 10 | + Divider, |
| 11 | + Flex, |
| 12 | + Heading, |
| 13 | + Item, |
| 14 | + Picker, |
| 15 | + Text, |
| 16 | + toast, |
| 17 | + type Key, |
| 18 | +} from '@geti/ui'; |
| 19 | +import { useMutation } from '@tanstack/react-query'; |
| 20 | +import type { SchemaCompressionType, SchemaExportType } from 'src/api/openapi-spec'; |
| 21 | +import { Onnx, OpenVino, PyTorch } from 'src/assets/icons'; |
| 22 | + |
| 23 | +import { downloadBlob, sanitizeFilename } from '../utils'; |
| 24 | +import type { ModelData } from './model-types'; |
| 25 | + |
| 26 | +import classes from './export-model-dialog.module.scss'; |
| 27 | + |
| 28 | +const EXPORT_FORMATS: { id: SchemaExportType; name: string; Icon: React.FC<React.SVGProps<SVGSVGElement>> }[] = [ |
| 29 | + { id: 'openvino', name: 'OpenVINO', Icon: OpenVino }, |
| 30 | + { id: 'onnx', name: 'ONNX', Icon: Onnx }, |
| 31 | + { id: 'torch', name: 'PyTorch', Icon: PyTorch }, |
| 32 | +]; |
| 33 | + |
| 34 | +const COMPRESSION_OPTIONS: { id: SchemaCompressionType | 'none'; name: string }[] = [ |
| 35 | + { id: 'none', name: 'None' }, |
| 36 | + { id: 'fp16', name: 'FP16' }, |
| 37 | + { id: 'int8', name: 'INT8' }, |
| 38 | + { id: 'int8_ptq', name: 'INT8 PTQ' }, |
| 39 | + { id: 'int8_acq', name: 'INT8 ACQ' }, |
| 40 | +]; |
| 41 | + |
| 42 | +interface ExportModelDialogProps { |
| 43 | + model: ModelData; |
| 44 | + close: () => void; |
| 45 | +} |
| 46 | + |
| 47 | +export const ExportModelDialog = ({ model, close }: ExportModelDialogProps) => { |
| 48 | + const { projectId } = useProjectIdentifier(); |
| 49 | + const { data: project } = $api.useSuspenseQuery('get', '/api/projects/{project_id}', { |
| 50 | + params: { path: { project_id: projectId } }, |
| 51 | + }); |
| 52 | + const [selectedFormat, setSelectedFormat] = useState<SchemaExportType>('openvino'); |
| 53 | + const [selectedCompression, setSelectedCompression] = useState<SchemaCompressionType | 'none'>('none'); |
| 54 | + |
| 55 | + const exportMutation = useMutation({ |
| 56 | + mutationFn: async () => { |
| 57 | + const compression = selectedCompression === 'none' ? null : selectedCompression; |
| 58 | + |
| 59 | + const response = await fetchClient.POST('/api/projects/{project_id}/models/{model_id}:export', { |
| 60 | + params: { |
| 61 | + path: { |
| 62 | + project_id: projectId, |
| 63 | + model_id: model.id, |
| 64 | + }, |
| 65 | + }, |
| 66 | + body: { |
| 67 | + format: selectedFormat, |
| 68 | + compression, |
| 69 | + }, |
| 70 | + parseAs: 'blob', |
| 71 | + }); |
| 72 | + |
| 73 | + if (response.error) { |
| 74 | + throw new Error('Export failed'); |
| 75 | + } |
| 76 | + |
| 77 | + const blob = response.data as Blob; |
| 78 | + const compressionSuffix = compression ? `_${compression}` : ''; |
| 79 | + const sanitizedProjectName = sanitizeFilename(project.name); |
| 80 | + const sanitizedModelName = sanitizeFilename(model.name); |
| 81 | + const filename = `${sanitizedProjectName}_${sanitizedModelName}_${selectedFormat}${compressionSuffix}.zip`; |
| 82 | + |
| 83 | + return { blob, filename }; |
| 84 | + }, |
| 85 | + onSuccess: ({ blob, filename }) => { |
| 86 | + downloadBlob(blob, filename); |
| 87 | + toast({ type: 'success', message: `Model "${model.name}" exported successfully.` }); |
| 88 | + close(); |
| 89 | + }, |
| 90 | + onError: () => { |
| 91 | + toast({ type: 'error', message: `Failed to export model "${model.name}".` }); |
| 92 | + }, |
| 93 | + }); |
| 94 | + |
| 95 | + const handleFormatChange = (value: string) => { |
| 96 | + const format = value as SchemaExportType; |
| 97 | + setSelectedFormat(format); |
| 98 | + |
| 99 | + if (format !== 'openvino') { |
| 100 | + setSelectedCompression('none'); |
| 101 | + } |
| 102 | + }; |
| 103 | + |
| 104 | + const handleCompressionChange = (key: Key | null) => { |
| 105 | + if (key === null) return; |
| 106 | + setSelectedCompression(key as SchemaCompressionType | 'none'); |
| 107 | + }; |
| 108 | + |
| 109 | + return ( |
| 110 | + <Dialog size='S'> |
| 111 | + <Heading>Export Model</Heading> |
| 112 | + <Divider /> |
| 113 | + <Content> |
| 114 | + <Flex direction='column' gap='size-200'> |
| 115 | + <Text> |
| 116 | + Export <strong>{model.name}</strong> to a downloadable format. |
| 117 | + </Text> |
| 118 | + |
| 119 | + <Flex direction='column' gap='size-100'> |
| 120 | + <Text UNSAFE_className={classes.label}>Export Format</Text> |
| 121 | + <div className={classes.formatGroup} role='radiogroup' aria-label='Select export format'> |
| 122 | + {EXPORT_FORMATS.map(({ id, Icon }) => ( |
| 123 | + <button |
| 124 | + key={id} |
| 125 | + type='button' |
| 126 | + role='radio' |
| 127 | + aria-checked={selectedFormat === id} |
| 128 | + onClick={() => handleFormatChange(id)} |
| 129 | + className={`${classes.formatOption} ${ |
| 130 | + selectedFormat === id ? classes.formatOptionSelected : '' |
| 131 | + }`} |
| 132 | + > |
| 133 | + <Icon className={classes.formatIcon} /> |
| 134 | + </button> |
| 135 | + ))} |
| 136 | + </div> |
| 137 | + </Flex> |
| 138 | + |
| 139 | + {selectedFormat === 'openvino' && ( |
| 140 | + <Picker |
| 141 | + label='Compression (optional)' |
| 142 | + items={COMPRESSION_OPTIONS} |
| 143 | + selectedKey={selectedCompression} |
| 144 | + onSelectionChange={handleCompressionChange} |
| 145 | + width='100%' |
| 146 | + > |
| 147 | + {(item) => <Item key={item.id}>{item.name}</Item>} |
| 148 | + </Picker> |
| 149 | + )} |
| 150 | + </Flex> |
| 151 | + </Content> |
| 152 | + <ButtonGroup> |
| 153 | + <Button variant='secondary' onPress={close} isDisabled={exportMutation.isPending}> |
| 154 | + Cancel |
| 155 | + </Button> |
| 156 | + <Button |
| 157 | + variant='accent' |
| 158 | + onPress={() => exportMutation.mutate()} |
| 159 | + isPending={exportMutation.isPending} |
| 160 | + isDisabled={exportMutation.isPending} |
| 161 | + > |
| 162 | + Export |
| 163 | + </Button> |
| 164 | + </ButtonGroup> |
| 165 | + </Dialog> |
| 166 | + ); |
| 167 | +}; |
0 commit comments