|
21 | 21 | import warnings |
22 | 22 | from abc import ABC, abstractmethod |
23 | 23 | from dataclasses import dataclass |
24 | | -from functools import cached_property |
| 24 | +from functools import cached_property, wraps |
25 | 25 | from itertools import chain |
26 | 26 | from types import TracebackType |
27 | 27 | from typing import ( |
|
41 | 41 |
|
42 | 42 | from pydantic import Field |
43 | 43 | from sortedcontainers import SortedList |
| 44 | +from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential |
44 | 45 |
|
| 46 | +from pyiceberg.exceptions import CommitFailedException |
45 | 47 | import pyiceberg.expressions.parser as parser |
46 | 48 | from pyiceberg.expressions import ( |
47 | 49 | AlwaysTrue, |
|
74 | 76 | from pyiceberg.schema import Schema |
75 | 77 | from pyiceberg.table.inspect import InspectTable |
76 | 78 | from pyiceberg.table.metadata import ( |
| 79 | + COMMIT_MAX_RETRY_WAIT_MS, |
| 80 | + COMMIT_MAX_RETRY_WAIT_MS_DEFAULT, |
| 81 | + COMMIT_MIN_RETRY_WAIT_MS, |
| 82 | + COMMIT_MIN_RETRY_WAIT_MS_DEFAULT, |
| 83 | + COMMIT_NUM_RETRIES, |
| 84 | + COMMIT_NUM_RETRIES_DEFAULT, |
77 | 85 | INITIAL_SEQUENCE_NUMBER, |
78 | 86 | TableMetadata, |
79 | 87 | ) |
|
89 | 97 | from pyiceberg.table.update import ( |
90 | 98 | AddPartitionSpecUpdate, |
91 | 99 | AddSchemaUpdate, |
| 100 | + AddSnapshotUpdate, |
92 | 101 | AddSortOrderUpdate, |
93 | 102 | AssertCreate, |
94 | 103 | AssertRefSnapshotId, |
@@ -1059,9 +1068,51 @@ def refs(self) -> Dict[str, SnapshotRef]: |
1059 | 1068 | return self.metadata.refs |
1060 | 1069 |
|
def _do_commit(self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequirement, ...]) -> None:
    """Commit the pending updates/requirements through the catalog, retrying on conflicts.

    When the catalog raises CommitFailedException (another writer won the race),
    the table metadata is refreshed and the pending changes are rebased onto the
    newly current snapshot before the commit is attempted again. Retry timing and
    attempt count are read from the table's commit.* properties.

    Args:
        updates: Metadata updates to apply (e.g. AddSnapshotUpdate).
        requirements: Optimistic-concurrency assertions the catalog must check
            (e.g. AssertRefSnapshotId).
    """

    def _on_error(*_: Any) -> None:
        # tenacity `after` hook: runs after each failed attempt, before the next
        # retry. Rebase the in-flight updates/requirements onto fresh metadata.
        nonlocal updates, requirements
        self.refresh()
        next_seq_num = self.metadata.next_sequence_number()
        updates = tuple(
            (
                update.model_copy(
                    update={
                        "snapshot": update.snapshot.model_copy(
                            update={
                                # BUGFIX: key was misspelled "parent_snaphot_id".
                                # model_copy(update=...) does not validate keys, so the
                                # typo silently set a bogus attribute and left the real
                                # parent_snapshot_id pointing at the stale parent.
                                "parent_snapshot_id": self.metadata.current_snapshot_id,
                                "sequence_number": next_seq_num,
                            }
                        ),
                    },
                )
                if isinstance(update, AddSnapshotUpdate)
                else update
            )
            for update in updates
        )
        # Re-point ref assertions at the now-current snapshot so the retried
        # commit asserts against what we just rebased onto.
        requirements = tuple(
            req.model_copy(update={"snapshot_id": self.metadata.current_snapshot_id})
            if isinstance(req, AssertRefSnapshotId)
            else req
            for req in requirements
        )

    # Retry tuning from table properties; values are milliseconds, tenacity wants seconds.
    min_wait_ms = int(self.metadata.properties.get(COMMIT_MIN_RETRY_WAIT_MS, COMMIT_MIN_RETRY_WAIT_MS_DEFAULT))
    max_wait_ms = int(self.metadata.properties.get(COMMIT_MAX_RETRY_WAIT_MS, COMMIT_MAX_RETRY_WAIT_MS_DEFAULT))
    num_retries = int(self.metadata.properties.get(COMMIT_NUM_RETRIES, COMMIT_NUM_RETRIES_DEFAULT))

    @wraps(self._do_commit)
    @retry(
        wait=wait_random_exponential(min=min_wait_ms / 1000, max=max_wait_ms / 1000),
        # NOTE(review): stop_after_attempt counts TOTAL attempts, so this allows
        # num_retries attempts (= num_retries - 1 retries) — confirm that matches
        # the intended semantics of the commit.num-retries property.
        stop=stop_after_attempt(num_retries),
        retry=retry_if_exception_type(CommitFailedException),
        after=_on_error,
    )
    def _do_commit_inner() -> None:
        # Only CommitFailedException triggers a retry; anything else propagates.
        response = self.catalog.commit_table(self, requirements, updates)
        self.metadata = response.metadata
        self.metadata_location = response.metadata_location

    return _do_commit_inner()
1065 | 1116 |
|
1066 | 1117 | def __eq__(self, other: Any) -> bool: |
1067 | 1118 | """Return the equality of two instances of the Table class.""" |
|
0 commit comments