1616 Dict ,
1717 List ,
1818 Optional ,
19+ Sequence ,
1920 Type ,
2021 Union ,
2122 cast ,
2425 TypedDict ,
2526 overload ,
2627)
27- from collections .abc import AsyncIterator , Iterator , Sequence
28+ from collections .abc import AsyncIterator , Iterator
2829
2930import proto # type: ignore[import-untyped]
3031
7475)
7576from langchain_core .utils .pydantic import is_basemodel_subclass
7677from langchain_core .utils .utils import _build_model_kwargs
77- from vertexai .generative_models import ( # type: ignore
78+ from vertexai .generative_models import (
7879 Tool as VertexTool ,
80+ Candidate as VertexCandidate ,
7981)
80- from vertexai .generative_models ._generative_models import ( # type: ignore
82+ from vertexai .generative_models ._generative_models import (
8183 ToolConfig ,
8284 SafetySettingsType ,
8385 GenerationConfigType ,
8486 GenerationResponse ,
8587 _convert_schema_dict_to_gapic ,
8688)
87- from vertexai .language_models import ( # type: ignore
89+ from vertexai .language_models import (
8890 ChatMessage ,
8991 InputOutputTextPair ,
9092)
@@ -227,10 +229,10 @@ def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
227229 if i == 0 and isinstance (message , SystemMessage ):
228230 context = content
229231 elif isinstance (message , AIMessage ):
230- vertex_message = ChatMessage (content = message . content , author = "bot" )
232+ vertex_message = ChatMessage (content = content , author = "bot" )
231233 vertex_messages .append (vertex_message )
232234 elif isinstance (message , HumanMessage ):
233- vertex_message = ChatMessage (content = message . content , author = "user" )
235+ vertex_message = ChatMessage (content = content , author = "user" )
234236 vertex_messages .append (vertex_message )
235237 else :
236238 msg = f"Unexpected message with type { type (message )} at the position { i } ."
@@ -559,16 +561,18 @@ def _parse_examples(examples: List[BaseMessage]) -> List[InputOutputTextPair]:
559561 f"{ type (example )} for the { i } th message."
560562 )
561563 raise ValueError (msg )
562- input_text = example .content
564+ input_text = cast ( "str" , example .content )
563565 if i % 2 == 1 :
564566 if not isinstance (example , AIMessage ):
565567 msg = (
566568 f"Expected the second message in a part to be from AI, got "
567569 f"{ type (example )} for the { i } th message."
568570 )
569571 raise ValueError (msg )
572+ # input_text is guaranteed to be set in the previous iteration
573+ assert input_text is not None
570574 pair = InputOutputTextPair (
571- input_text = input_text , output_text = example .content
575+ input_text = input_text , output_text = cast ( "str" , example .content )
572576 )
573577 example_pairs .append (pair )
574578 return example_pairs
@@ -608,18 +612,19 @@ def _append_to_content(
608612
609613@overload
610614def _parse_response_candidate (
611- response_candidate : Candidate , streaming : Literal [False ] = False
615+ response_candidate : Union [Candidate , VertexCandidate ],
616+ streaming : Literal [False ] = False ,
612617) -> AIMessage : ...
613618
614619
615620@overload
616621def _parse_response_candidate (
617- response_candidate : Candidate , streaming : Literal [True ]
622+ response_candidate : Union [ Candidate , VertexCandidate ] , streaming : Literal [True ]
618623) -> AIMessageChunk : ...
619624
620625
621626def _parse_response_candidate (
622- response_candidate : Candidate , streaming : bool = False
627+ response_candidate : Union [ Candidate , VertexCandidate ] , streaming : bool = False
623628) -> AIMessage :
624629 content : Union [None , str , List [Union [str , dict [str , Any ]]]] = None
625630 additional_kwargs = {}
@@ -635,7 +640,7 @@ def _parse_response_candidate(
635640 except AttributeError :
636641 pass
637642
638- if part .thought :
643+ if hasattr ( part , "thought" ) and part .thought :
639644 thinking_message = {
640645 "type" : "thinking" ,
641646 "thinking" : part .text ,
@@ -694,7 +699,9 @@ def _parse_response_candidate(
694699
695700 if getattr (part , "thought_signature" , None ):
696701 # store dict of {tool_call_id: thought_signature}
697- if isinstance (part .thought_signature , bytes ):
702+ if hasattr (part , "thought_signature" ) and isinstance (
703+ part .thought_signature , bytes
704+ ):
698705 thought_signature = _bytes_to_base64 (part .thought_signature )
699706 if (
700707 _FUNCTION_CALL_THOUGHT_SIGNATURES_MAP_KEY
@@ -1990,7 +1997,7 @@ def _safety_settings_gemini(
19901997 return self ._safety_settings_gemini (self .safety_settings )
19911998 return None
19921999 if isinstance (safety_settings , list ):
1993- return safety_settings
2000+ return cast ( "Sequence[SafetySetting]" , safety_settings )
19942001 if isinstance (safety_settings , dict ):
19952002 formatted_safety_settings = []
19962003 for category , threshold in safety_settings .items ():
@@ -2006,8 +2013,8 @@ def _safety_settings_gemini(
20062013 )
20072014 )
20082015 return formatted_safety_settings
2009- msg = "safety_settings should be either"
2010- raise ValueError (msg )
2016+ # This should be unreachable as all cases are handled above
2017+ raise ValueError ("Unexpected safety_settings type" )
20112018
20122019 def _prepare_request_gemini (
20132020 self ,
@@ -2041,7 +2048,7 @@ def _prepare_request_gemini(
20412048 tool_config = _tool_choice_to_tool_config (tool_choice , all_names )
20422049 else :
20432050 pass
2044- safety_settings = self ._safety_settings_gemini (safety_settings )
2051+ formatted_safety_settings = self ._safety_settings_gemini (safety_settings )
20452052 logprobs = logprobs if logprobs is not None else self .logprobs
20462053 logprobs = logprobs if isinstance (logprobs , (int , bool )) else False
20472054 generation_config = self ._generation_config_gemini (
@@ -2100,18 +2107,22 @@ def _content_to_v1(contents: list[Content]) -> list[v1Content]:
21002107 v1_tools = [v1Tool (** proto .Message .to_dict (t )) for t in formatted_tools ]
21012108
21022109 if tool_config :
2103- v1_tool_config = v1ToolConfig (
2104- function_calling_config = v1FunctionCallingConfig (
2105- ** proto .Message .to_dict (tool_config .function_calling_config )
2110+ v1_tool_config = (
2111+ v1ToolConfig (
2112+ function_calling_config = v1FunctionCallingConfig (
2113+ ** proto .Message .to_dict (tool_config .function_calling_config )
2114+ )
21062115 )
2116+ if hasattr (tool_config , "function_calling_config" )
2117+ else v1ToolConfig ()
21072118 )
21082119
2109- if safety_settings :
2120+ if formatted_safety_settings :
21102121 v1_safety_settings = [
21112122 v1SafetySetting (
21122123 category = s .category , method = s .method , threshold = s .threshold
21132124 )
2114- for s in safety_settings
2125+ for s in formatted_safety_settings
21152126 ]
21162127
21172128 if (self .cached_content is not None ) or (cached_content is not None ):
@@ -2267,7 +2278,7 @@ def _tool_config_gemini(
22672278 self , tool_config : Optional [Union [_ToolConfigDict , ToolConfig ]] = None
22682279 ) -> Optional [GapicToolConfig ]:
22692280 if tool_config and not isinstance (tool_config , ToolConfig ):
2270- return _format_tool_config (cast ( "_ToolConfigDict" , tool_config ) )
2281+ return _format_tool_config (tool_config )
22712282 return None
22722283
22732284 async def _agenerate (
0 commit comments