Skip to content

Commit

Permalink
Address pr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
probicheaux committed Nov 13, 2024
1 parent fcbf977 commit 917d784
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 20 deletions.
9 changes: 0 additions & 9 deletions inference/core/workflows/core_steps/common/vlms.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,3 @@
"description": "Model returns a JSON response with the specified fields",
},
}


FLORENCE_TASKS_METADATA = {
"unstructured": {
"name": "Unstructured Prompt",
"description": "Use free-form prompt to generate a response. Useful with finetuned models.",
},
**VLM_TASKS_METADATA,
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from inference.core.entities.requests.inference import LMMInferenceRequest
from inference.core.managers.base import ModelManager
from inference.core.workflows.core_steps.common.entities import StepExecutionMode
from inference.core.workflows.core_steps.common.vlms import FLORENCE_TASKS_METADATA
from inference.core.workflows.core_steps.common.vlms import VLM_TASKS_METADATA
from inference.core.workflows.execution_engine.entities.base import (
Batch,
OutputDefinition,
Expand Down Expand Up @@ -37,6 +37,14 @@
T = TypeVar("T")
K = TypeVar("K")

FLORENCE_TASKS_METADATA = {
"custom": {
"name": "Custom Prompt",
"description": "Use free-form prompt to generate a response. Useful with finetuned models.",
},
**VLM_TASKS_METADATA,
}

DETECTIONS_CLASS_NAME_FIELD = "class_name"
DETECTION_ID_FIELD = "detection_id"

Expand Down Expand Up @@ -77,7 +85,7 @@
},
{"task_type": "detection-grounded-ocr", "florence_task": "<REGION_TO_OCR>"},
{"task_type": "region-proposal", "florence_task": "<REGION_PROPOSAL>"},
{"task_type": "unstructured", "florence_task": ""},
{"task_type": "custom", "florence_task": None},
]
TASK_TYPE_TO_FLORENCE_TASK = {
task["task_type"]: task["florence_task"] for task in SUPPORTED_TASK_TYPES_LIST
Expand Down Expand Up @@ -364,6 +372,8 @@ def run_locally(
grounding_selection_mode: GroundingSelectionMode,
) -> BlockResult:
requires_detection_grounding = task_type in TASKS_REQUIRING_DETECTION_GROUNDING

is_not_florence_task = task_type == "custom"
task_type = TASK_TYPE_TO_FLORENCE_TASK[task_type]
inference_images = [
i.to_inference_format(numpy_preferred=False) for i in images
Expand Down Expand Up @@ -391,17 +401,22 @@ def run_locally(
{"raw_output": None, "parsed_output": None, "classes": None}
)
continue
if is_not_florence_task:
prompt = single_prompt or ""
else:
prompt = task_type + (single_prompt or "")

request = LMMInferenceRequest(
api_key=self._api_key,
model_id=model_version,
image=image,
source="workflow-execution",
prompt=task_type + (single_prompt or ""),
prompt=prompt,
)
prediction = self._model_manager.infer_from_request_sync(
model_id=model_version, request=request
)
if task_type == "":
if is_not_florence_task:
prediction_data = prediction.response[
list(prediction.response.keys())[0]
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,11 @@ def run(
grounding_selection_mode: GroundingSelectionMode,
) -> BlockResult:
return super().run(
images,
model_id,
task_type,
prompt,
classes,
grounding_detection,
grounding_selection_mode,
images=images,
model_version=model_id,
task_type=task_type,
prompt=prompt,
classes=classes,
grounding_detection=grounding_detection,
grounding_selection_mode=grounding_selection_mode,
)

0 comments on commit 917d784

Please sign in to comment.