Skip to content

Commit 6edb9ab

Browse files
fix other error handling
1 parent a9335dc commit 6edb9ab

10 files changed

Lines changed: 107 additions & 69 deletions

File tree

kfinance/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Changelog
22

33
## v5.1.1
4-
- Update estimates tools to gracefully handle errors
4+
- Update estimates, line item, statements, and segments tools to more gracefully handle errors
55

66
## v5.1.0
77
- Add target price and analyst recommendation tools

kfinance/client/tests/test_fetch.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ def test_fetch_estimate(self) -> None:
425425
)
426426
assert isinstance(result, EstimatesResp)
427427
assert result.result is None
428-
assert result.errors == {}
428+
assert result.error is None
429429

430430
def test_fetch_consensus_target_price(self) -> None:
431431
company_id = 21719
@@ -456,7 +456,6 @@ def test_fetch_consensus_target_price(self) -> None:
456456
),
457457
],
458458
),
459-
errors={},
460459
)
461460

462461
result = self.kfinance_api_client.fetch_consensus_target_price(
@@ -492,7 +491,6 @@ def test_fetch_analyst_recommendations(self) -> None:
492491
),
493492
],
494493
),
495-
errors={},
496494
)
497495

498496
result = self.kfinance_api_client.fetch_analyst_recommendations(

kfinance/client/tests/test_objects.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -499,17 +499,17 @@ def fetch_estimates(
499499
period_type,
500500
):
501501
estimates = MOCK_COMPANY_DB[company_id]["estimates"]
502-
return EstimatesResp(result=estimates, errors={})
502+
return EstimatesResp(result=estimates)
503503

504504
def fetch_consensus_target_price(self, company_id):
505505
"""Get consensus target price estimates"""
506506
consensus_target_price = MOCK_COMPANY_DB[company_id]["consensus_target_price"]
507-
return ConsensusTargetPriceResp(result=consensus_target_price, errors={})
507+
return ConsensusTargetPriceResp(result=consensus_target_price)
508508

509509
def fetch_analyst_recommendations(self, company_id):
510510
"""Get analyst recommendations"""
511511
analyst_recommendations = MOCK_COMPANY_DB[company_id]["analyst_recommendations"]
512-
return AnalystRecommendationsResp(result=analyst_recommendations, errors={})
512+
return AnalystRecommendationsResp(result=analyst_recommendations)
513513

514514
def fetch_line_item(
515515
self,

kfinance/domains/estimates/estimates_models.py

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from datetime import date
22
from decimal import Decimal
33
import logging
4-
from typing import Any
4+
from typing import Any, Callable, Dict
55

6-
from pydantic import BaseModel, model_validator
6+
from pydantic import BaseModel, model_serializer, model_validator
77

88
from kfinance.client.models.date_and_period_models import EstimatePeriodType, EstimateType
9-
from kfinance.client.models.response_models import RespWithErrors
109

1110

1211
logger = logging.getLogger(__name__)
@@ -29,10 +28,20 @@ class Estimates(BaseModel):
2928
periods: dict[str, EstimatesPeriodData]
3029

3130

32-
class EstimatesResp(RespWithErrors):
31+
class EstimatesResp(BaseModel):
3332
"""Response model for a single company's estimates."""
3433

3534
result: Estimates | None = None
35+
error: str | None = None
36+
37+
@model_serializer(mode="wrap")
38+
def serialize_model(self, handler: Callable) -> Dict[str, Any]:
39+
"""Make `error` the last response field and only include if there is at least one error."""
40+
data = handler(self)
41+
error = data.pop("error")
42+
if error:
43+
data["error"] = error
44+
return data
3645

3746
@model_validator(mode="before")
3847
@classmethod
@@ -43,7 +52,11 @@ def from_post_response(cls, data: Any) -> Any:
4352
if len(results) > 1:
4453
logger.warning("Expected at most one result, got %d", len(results))
4554
result = next(iter(results.values()), None)
46-
return {"result": result, "errors": data.get("errors", {})}
55+
errors = data.get("errors", {})
56+
if len(errors) > 1:
57+
logger.warning("Expected at most one error, got %d", len(errors))
58+
error = next(iter(errors.values()), None)
59+
return {"result": result, "error": error}
4760
return data
4861

4962

@@ -58,10 +71,20 @@ class ConsensusTargetPrice(BaseModel):
5871
estimates: list[ConsensusTargetPriceItem]
5972

6073

61-
class ConsensusTargetPriceResp(RespWithErrors):
74+
class ConsensusTargetPriceResp(BaseModel):
6275
"""Response model for a single company's consensus target price."""
6376

6477
result: ConsensusTargetPrice | None = None
78+
error: str | None = None
79+
80+
@model_serializer(mode="wrap")
81+
def serialize_model(self, handler: Callable) -> Dict[str, Any]:
82+
"""Make `error` the last response field and only include if there is at least one error."""
83+
data = handler(self)
84+
error = data.pop("error")
85+
if error:
86+
data["error"] = error
87+
return data
6588

6689
@model_validator(mode="before")
6790
@classmethod
@@ -72,7 +95,11 @@ def from_post_response(cls, data: Any) -> Any:
7295
if len(results) > 1:
7396
logger.warning("Expected at most one result, got %d", len(results))
7497
result = next(iter(results.values()), None)
75-
return {"result": result, "errors": data.get("errors", {})}
98+
errors = data.get("errors", {})
99+
if len(errors) > 1:
100+
logger.warning("Expected at most one error, got %d", len(errors))
101+
error = next(iter(errors.values()), None)
102+
return {"result": result, "error": error}
76103
return data
77104

78105

@@ -86,10 +113,20 @@ class AnalystRecommendations(BaseModel):
86113
estimates: list[AnalystRecommendationsItem]
87114

88115

89-
class AnalystRecommendationsResp(RespWithErrors):
116+
class AnalystRecommendationsResp(BaseModel):
90117
"""Response model for a single company's analyst recommendations."""
91118

92119
result: AnalystRecommendations | None = None
120+
error: str | None = None
121+
122+
@model_serializer(mode="wrap")
123+
def serialize_model(self, handler: Callable) -> Dict[str, Any]:
124+
"""Make `error` the last response field and only include if there is at least one error."""
125+
data = handler(self)
126+
error = data.pop("error")
127+
if error:
128+
data["error"] = error
129+
return data
93130

94131
@model_validator(mode="before")
95132
@classmethod
@@ -100,5 +137,9 @@ def from_post_response(cls, data: Any) -> Any:
100137
if len(results) > 1:
101138
logger.warning("Expected at most one result, got %d", len(results))
102139
result = next(iter(results.values()), None)
103-
return {"result": result, "errors": data.get("errors", {})}
140+
errors = data.get("errors", {})
141+
if len(errors) > 1:
142+
logger.warning("Expected at most one error, got %d", len(errors))
143+
error = next(iter(errors.values()), None)
144+
return {"result": result, "error": error}
104145
return data

kfinance/domains/estimates/estimates_tools.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,9 @@ async def get_estimates_from_identifiers(
226226
resp: EstimatesResp = task.result
227227
if resp.result is not None:
228228
results[task.result_key] = resp.result
229-
errors.extend(resp.errors.values())
229+
if resp.error is not None:
230+
error_msg = f"{task.result_key}: {resp.error}"
231+
errors.append(error_msg)
230232

231233
resp_model = GetEstimatesFromIdentifiersResp(results=results, errors=errors)
232234

@@ -311,7 +313,9 @@ async def get_consensus_target_price_from_identifiers(
311313
resp: ConsensusTargetPriceResp = task.result
312314
if resp.result is not None:
313315
results[task.result_key] = resp.result
314-
errors.extend(resp.errors.values())
316+
if resp.error is not None:
317+
error_msg = f"{task.result_key}: {resp.error}"
318+
errors.append(error_msg)
315319

316320
return GetConsensusTargetPriceFromIdentifiersResp(results=results, errors=errors)
317321

@@ -358,7 +362,9 @@ async def get_analyst_recommendations_from_identifiers(
358362
resp: AnalystRecommendationsResp = task.result
359363
if resp.result is not None:
360364
results[task.result_key] = resp.result
361-
errors.extend(resp.errors.values())
365+
if resp.error is not None:
366+
error_msg = f"{task.result_key}: {resp.error}"
367+
errors.append(error_msg)
362368

363369
return GetAnalystRecommendationsFromIdentifiersResp(results=results, errors=errors)
364370

kfinance/domains/estimates/tests/test_estimate_tools.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ async def test_fetch_estimates_from_company_id(
9898

9999
expected_resp = EstimatesResp(
100100
result=Estimates.model_validate(self.estimates_data),
101-
errors={},
102101
)
103102
assert resp == expected_resp
104103

@@ -158,7 +157,6 @@ async def test_get_estimates_with_guidance_type(
158157

159158
expected_resp = EstimatesResp(
160159
result=Estimates.model_validate(guidance_data),
161-
errors={},
162160
)
163161
assert resp == expected_resp
164162
assert resp.result.estimate_type == EstimateType.guidance
@@ -200,7 +198,6 @@ async def test_fetch_consensus_target_price_from_company_id(
200198

201199
expected_resp = ConsensusTargetPriceResp(
202200
result=ConsensusTargetPrice.model_validate(consensus_target_price_data),
203-
errors={},
204201
)
205202
assert resp == expected_resp
206203

@@ -282,7 +279,6 @@ async def test_fetch_analyst_recommendations_from_company_id(
282279

283280
expected_resp = AnalystRecommendationsResp(
284281
result=AnalystRecommendations.model_validate(analyst_recommendations_data),
285-
errors={},
286282
)
287283
assert resp == expected_resp
288284

kfinance/domains/line_items/line_item_tools.py

Lines changed: 32 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
import httpx
66
from pydantic import BaseModel, Field, model_validator
77

8-
from kfinance.async_batch_execution import AsyncTask, batch_execute_async_tasks
98
from kfinance.client.id_resolution import unified_fetch_id_triples
109
from kfinance.client.models.date_and_period_models import NumPeriods, NumPeriodsBack, PeriodType
10+
from kfinance.client.models.response_models import PostResponse
1111
from kfinance.client.permission_models import Permission
1212
from kfinance.domains.line_items.line_item_models import (
1313
LINE_ITEM_NAMES_AND_ALIASES,
@@ -241,40 +241,38 @@ async def get_financial_line_item_from_identifiers(
241241
id_triple.company_id for id_triple in id_triple_resp.identifiers_to_id_triples.values()
242242
]
243243

244-
# Create a single task to fetch line items for all company IDs at once
244+
# Fetch line items for all company IDs at once
245245
if company_ids:
246-
task = AsyncTask(
247-
func=fetch_line_item_from_company_ids,
248-
kwargs=dict(
249-
company_ids=company_ids,
250-
line_item=line_item,
251-
httpx_client=httpx_client,
252-
period_type=period_type,
253-
start_year=start_year,
254-
end_year=end_year,
255-
start_quarter=start_quarter,
256-
end_quarter=end_quarter,
257-
calendar_type=calendar_type,
258-
num_periods=num_periods,
259-
num_periods_back=num_periods_back,
260-
),
261-
result_key="line_items",
246+
line_item_resp = await fetch_line_item_from_company_ids(
247+
company_ids=company_ids,
248+
line_item=line_item,
249+
httpx_client=httpx_client,
250+
period_type=period_type,
251+
start_year=start_year,
252+
end_year=end_year,
253+
start_quarter=start_quarter,
254+
end_quarter=end_quarter,
255+
calendar_type=calendar_type,
256+
num_periods=num_periods,
257+
num_periods_back=num_periods_back,
262258
)
263259

264-
await batch_execute_async_tasks(tasks=[task])
265-
266-
if task.error:
267-
errors.append(task.error)
268-
results = {}
269-
else:
270-
# Map company IDs back to original identifiers
271-
identifier_to_results = {}
272-
for company_id_str, line_item_resp in task.result.items():
273-
original_identifier = id_triple_resp.get_identifier_from_company_id(
274-
int(company_id_str)
275-
)
276-
identifier_to_results[original_identifier] = line_item_resp
277-
results = identifier_to_results
260+
# Add any errors from the line item API, mapping company_id keys back to identifiers
261+
for company_id_str, error in line_item_resp.errors.items():
262+
original_identifier = id_triple_resp.get_identifier_from_company_id(
263+
int(company_id_str)
264+
)
265+
errors.append(f"{original_identifier}: {error}")
266+
267+
# Map results back to original identifiers
268+
identifier_to_results = {}
269+
for company_id_str, line_item_data in line_item_resp.results.items():
270+
original_identifier = id_triple_resp.get_identifier_from_company_id(
271+
int(company_id_str)
272+
)
273+
identifier_to_results[original_identifier] = line_item_data
274+
275+
results = identifier_to_results
278276
else:
279277
results = {}
280278

@@ -316,7 +314,7 @@ async def fetch_line_item_from_company_ids(
316314
calendar_type: CalendarType | None = None,
317315
num_periods: int | None = None,
318316
num_periods_back: int | None = None,
319-
) -> dict[str, LineItemResp]:
317+
) -> PostResponse[LineItemResp]:
320318
"""Fetch line items for a list of company IDs."""
321319
# Build the request payload
322320
params: dict[str, Any] = {
@@ -342,11 +340,5 @@ async def fetch_line_item_from_company_ids(
342340
params["num_periods_back"] = num_periods_back
343341

344342
resp = await httpx_client.post(url="/line_item/", json=params)
345-
response_data = resp.json()
346-
347-
# Convert the response data to LineItemResp objects
348-
results = {}
349-
for company_id_str, line_item_data in response_data["results"].items():
350-
results[company_id_str] = LineItemResp.model_validate(line_item_data)
351343

352-
return results
344+
return PostResponse[LineItemResp].model_validate(resp.json())

kfinance/domains/line_items/tests/test_line_item_tools.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,11 @@ async def test_fetch_line_item_from_company_ids(
7070
httpx_client=httpx_client,
7171
)
7272

73-
expected_resp = {
73+
expected_results = {
7474
str(SPGI_ID_TRIPLE.company_id): LineItemResp.model_validate(self.line_item_resp)
7575
}
76-
assert resp == expected_resp
76+
assert resp.results == expected_results
77+
assert resp.errors == {}
7778

7879
@pytest.mark.parametrize(
7980
"calendar_type, expected_notes",

kfinance/domains/segments/segment_tools.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,10 @@ async def get_segments_from_identifiers(
136136
httpx_client=httpx_client,
137137
)
138138

139-
# Add any errors from the segments API
140-
errors.extend(segments_resp.errors.values())
139+
# Add any errors from the segments API, mapping company_id keys back to identifiers
140+
for company_id_str, error in segments_resp.errors.items():
141+
original_identifier = id_triple_resp.get_identifier_from_company_id(int(company_id_str))
142+
errors.append(f"{original_identifier}: {error}")
141143

142144
# Map results back to original identifiers
143145
identifier_to_results = {}

kfinance/domains/statements/statement_tools.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,10 @@ async def get_financial_statement_from_identifiers(
160160
httpx_client=httpx_client,
161161
)
162162

163-
# Add any errors from the statements API
164-
errors.extend(statements_resp.errors.values())
163+
# Add any errors from the statements API, mapping company_id keys back to identifiers
164+
for company_id_str, error in statements_resp.errors.items():
165+
original_identifier = id_triple_resp.get_identifier_from_company_id(int(company_id_str))
166+
errors.append(f"{original_identifier}: {error}")
165167

166168
# Map results back to original identifiers
167169
identifier_to_results = {}

0 commit comments

Comments
 (0)