diff --git a/.gitignore b/.gitignore index de762cac9c8..7285116f6ca 100644 --- a/.gitignore +++ b/.gitignore @@ -88,3 +88,6 @@ packages/grid/helm/examples/dev/migration.yaml notebooks/scenarios/bigquery/*.json + +notebooks/tutorials/version-upgrades/*.yaml +notebooks/tutorials/version-upgrades/*.blob diff --git a/notebooks/scenarios/bigquery/sync/040-do-review-requests.ipynb b/notebooks/scenarios/bigquery/sync/040-do-review-requests.ipynb index 07b32abbc34..e9b43ae3eb1 100644 --- a/notebooks/scenarios/bigquery/sync/040-do-review-requests.ipynb +++ b/notebooks/scenarios/bigquery/sync/040-do-review-requests.ipynb @@ -209,8 +209,9 @@ "metadata": {}, "outputs": [], "source": [ - "assert len(diffs.batches) == 1\n", - "assert diffs.batches[0].root_diff.obj_type.__qualname__ == \"Job\"" + "batch_root_strs = [x.root_diff.obj_type.__qualname__ for x in diffs.batches]\n", + "assert len(diffs.batches) == 3\n", + "assert \"Job\" in batch_root_strs" ] }, { @@ -285,11 +286,6 @@ } ], "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -300,7 +296,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.4" + "version": "3.12.5" } }, "nbformat": 4, diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index b47aba7a604..ce179786e4c 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -18,6 +18,13 @@ "action": "add" } }, + "SyftWorker": { + "2": { + "version": 2, + "hash": "e996dabbb8ad4ff0bc5d19528077c11f73b9300d810735d367916e4e5b9149b6", + "action": "add" + } + }, "WorkerSettings": { "2": { "version": 2, @@ -25,6 +32,27 @@ "action": "add" } }, + "ApprovalDecision": { + "1": { + "version": 1, + "hash": "ecce7c6e01af68b0c0a73605f0c2226917f0784ecce69e9f64ce004b243252d4", + "action": "add" + } + }, + "UserCodeStatusCollection": { + "2": { + "version": 2, + "hash": "22a1574d4d2d5dcfa26791f2a5007bf3885dae707e175bf8cc20d0803ae54dec", + "action": "add" + } + }, + "UserCode": { + "2": { + "version": 2, + "hash": "726bc406449178029c04b0b21b50f86ea12b18ea5b7dd030ad7dbfc6e60f6909", + "action": "add" + } + }, "QueueItem": { "2": { "version": 2, @@ -66,13 +94,6 @@ "hash": "2e1365c5535fa51c22eef79f67dd6444789bc829c27881367e3050e06e2ffbfe", "action": "remove" } - }, - "SyftWorker": { - "2": { - "version": 2, - "hash": "e996dabbb8ad4ff0bc5d19528077c11f73b9300d810735d367916e4e5b9149b6", - "action": "add" - } } } } diff --git a/packages/syft/src/syft/service/code/status_service.py b/packages/syft/src/syft/service/code/status_service.py index d6c1a56e801..89871e63c19 100644 --- a/packages/syft/src/syft/service/code/status_service.py +++ b/packages/syft/src/syft/service/code/status_service.py @@ -3,10 +3,13 @@ # third party # relative +from ...client.api import ServerIdentity from ...serde.serializable import serializable from ...store.db.db import DBManager from ...store.db.stash import ObjectStash from ...store.document_store import PartitionSettings +from ...types.syft_object import PartialSyftObject +from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.uid import UID from ..context import AuthedServiceContext from ..response import SyftSuccess @@ -15,6 +18,7 @@ from ..service import service_method from ..user.user_roles import ADMIN_ROLE_LEVEL from ..user.user_roles import GUEST_ROLE_LEVEL +from .user_code import ApprovalDecision from .user_code import UserCodeStatusCollection @@ -26,6 +30,14 @@ class StatusStash(ObjectStash[UserCodeStatusCollection]): ) +class CodeStatusUpdate(PartialSyftObject): + __canonical_name__ = "CodeStatusUpdate" + __version__ = SYFT_OBJECT_VERSION_1 + + id: UID + decision: ApprovalDecision + + @serializable(canonical_name="UserCodeStatusService", version=1) class UserCodeStatusService(AbstractService): stash: StatusStash @@ -39,10 +51,30 @@ def create( context: AuthedServiceContext, status: UserCodeStatusCollection, ) -> UserCodeStatusCollection: - return self.stash.set( + res = self.stash.set( credentials=context.credentials, obj=status, ).unwrap() + return res + + @service_method( + path="code_status.update", + name="update", + roles=ADMIN_ROLE_LEVEL, + autosplat=["code_update"], + unwrap_on_success=False, + ) + def update( + self, context: AuthedServiceContext, code_update: CodeStatusUpdate + ) -> SyftSuccess: + existing_status = self.stash.get_by_uid( + context.credentials, uid=code_update.id + ).unwrap() + server_identity = ServerIdentity.from_server(context.server) + existing_status.status_dict[server_identity] = code_update.decision + + res = self.stash.update(context.credentials, existing_status).unwrap() + return SyftSuccess(message="UserCode updated successfully", value=res) @service_method( path="code_status.get_by_uid", name="get_by_uid", roles=GUEST_ROLE_LEVEL diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 81e586d5c02..fea4927cb82 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -51,6 +51,7 @@ from ...types.dicttuple import DictTuple from ...types.errors import SyftException from ...types.result import as_result +from ...types.syft_migration import migrate from ...types.syft_object import PartialSyftObject from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SYFT_OBJECT_VERSION_2 @@ -58,7 +59,9 @@ from ...types.syncable_object import SyncableSyftObject from ...types.transforms import TransformContext from ...types.transforms import add_server_uid_for_key +from ...types.transforms import drop from ...types.transforms import generate_id +from ...types.transforms import make_set_default from ...types.transforms import transform from ...types.uid import UID from ...util.decorators import deprecated @@ -128,14 +131,159 @@ def __hash__(self) -> int: @serializable() -class UserCodeStatusCollection(SyncableSyftObject): +class ApprovalDecision(SyftObject): + status: UserCodeStatus + reason: str | None = None + + __canonical_name__ = "ApprovalDecision" + __version__ = 1 + + @property + def reason_or_none(self) -> str | None: + # TODO: move to class creation + if self.reason == "": + return None + return self.reason + + +@serializable() +class UserCodeStatusCollectionV1(SyncableSyftObject): + """Currently this is a class that implements a mixed bag of two statusses + The first status is for a level 0 Request, which only uses the status dict + for denied decision. If there is no denied decision, it computes the status + by checking the backend for whether it has readable outputs. + The second use case is for a level 2 Request, in this case we store the status + dict on the object and use it as is for both denied and approved status + """ + __canonical_name__ = "UserCodeStatusCollection" __version__ = SYFT_OBJECT_VERSION_1 __repr_attrs__ = ["approved", "status_dict"] + + # this is empty in the case of l0 status_dict: dict[ServerIdentity, tuple[UserCodeStatus, str]] = {} + user_code_link: LinkedObject + +@serializable() +class UserCodeStatusCollection(SyncableSyftObject): + """Currently this is a class that implements a mixed bag of two statusses + The first status is for a level 0 Request, which only uses the status dict + for denied decision. If there is no denied decision, it computes the status + by checking the backend for whether it has readable outputs. + The second use case is for a level 2 Request, in this case we store the status + dict on the object and use it as is for both denied and approved status + """ + + __canonical_name__ = "UserCodeStatusCollection" + __version__ = SYFT_OBJECT_VERSION_2 + + __repr_attrs__ = ["approved", "status_dict"] + + # this is empty in the case of l0 + status_dict: dict[ServerIdentity, ApprovalDecision] = {} + + user_code_link: LinkedObject + user_verify_key: SyftVerifyKey + + was_requested_on_lowside: bool = False + + # ugly and buggy optimization, remove at some point + _has_readable_outputs_cache: bool | None = None + + @property + def approved(self) -> bool: + # only use this on the client side, in this case we can use self.get_api instead + # of using the context + return self.get_is_approved(None) + + def get_is_approved(self, context: AuthedServiceContext | None) -> bool: + return self._compute_status(context) == UserCodeStatus.APPROVED + + def _compute_status( + self, context: AuthedServiceContext | None = None + ) -> UserCodeStatus: + if self.was_requested_on_lowside: + return self._compute_status_l0(context) + else: + return self._compute_status_l2() + + @property + def denied(self) -> bool: + # for denied we use the status dict both for level 0 and level 2 + return any( + approval_dec.status == UserCodeStatus.DENIED + for approval_dec in self.status_dict.values() + ) + + def _compute_status_l0( + self, context: AuthedServiceContext | None = None + ) -> UserCodeStatus: + # for l0, if denied in status dict, its denied + # if not, and it has readable outputs, its approved, + # else pending + + has_readable_outputs = self._has_readable_outputs(context) + + if self.denied: + if has_readable_outputs: + prompt_warning_message( + "This request already has results published to the data scientist. " + "They will still be able to access those results." + ) + return UserCodeStatus.DENIED + elif has_readable_outputs: + return UserCodeStatus.APPROVED + else: + return UserCodeStatus.PENDING + + def _compute_status_l2(self) -> UserCodeStatus: + any_denied = any( + approval_dec.status == UserCodeStatus.DENIED + for approval_dec in self.status_dict.values() + ) + all_approved = all( + approval_dec.status == UserCodeStatus.APPROVED + for approval_dec in self.status_dict.values() + ) + if any_denied: + return UserCodeStatus.DENIED + elif all_approved: + return UserCodeStatus.APPROVED + else: + return UserCodeStatus.PENDING + + def _has_readable_outputs( + self, context: AuthedServiceContext | None = None + ) -> bool: + if context is None: + # Clientside + api = self._get_api() + if self._has_readable_outputs_cache is None: + has_readable_outputs = api.output.has_output_read_permissions( + self.user_code_link.object_uid, self.user_verify_key + ) + self._has_readable_outputs_cache = has_readable_outputs + return has_readable_outputs + else: + return self._has_readable_outputs_cache + else: + # Serverside + return context.server.services.output.has_output_read_permissions( + context, self.user_code_link.object_uid, self.user_verify_key + ) + + @property + def first_denial_reason(self) -> str: + denial_reasons = [ + x.reason_or_none + for x in self.status_dict.values() + if x.status == UserCodeStatus.DENIED and x.reason_or_none is not None + ] + return next(iter(denial_reasons), "") + def syft_get_diffs(self, ext_obj: Any) -> list[AttrDiff]: # relative from ...service.sync.diff_state import AttrDiff @@ -162,87 +310,51 @@ def _repr_html_(self) -> str:

User Code Status

""" - for server_identity, (status, reason) in self.status_dict.items(): + for server_identity, approval_decision in self.status_dict.items(): server_name_str = f"{server_identity.server_name}" uid_str = f"{server_identity.server_id}" - status_str = f"{status.value}" + status_str = f"{approval_decision.status.value}" string += f""" • UID: {uid_str}  Server name: {server_name_str}  Status: {status_str}; - Reason: {reason} + Reason: {approval_decision.reason}
""" string += "

" return string def __repr_syft_nested__(self) -> str: - string = "" - for server_identity, (status, reason) in self.status_dict.items(): - string += f"{server_identity.server_name}: {status}, {reason}
" - return string + # this currently assumes that there is only one status + status_str = self._compute_status().value + + if self.denied: + status_str = f"{status_str}: self.first_denial_reason" + return status_str - def get_status_message(self) -> str: - if self.approved: + def get_status_message_l2(self, context: AuthedServiceContext) -> str: + if self.get_is_approved(context): return f"{type(self)} approved" denial_string = "" string = "" - for server_identity, (status, reason) in self.status_dict.items(): - denial_string += f"Code status on server '{server_identity.server_name}' is '{status}'. Reason: {reason}" - if not reason.endswith("."): - denial_string += "." - string += ( - f"Code status on server '{server_identity.server_name}' is '{status}'." + + for server_identity, approval_decision in self.status_dict.items(): + denial_string += ( + f"Code status on server '{server_identity.server_name}' is '{approval_decision.status}'." + f" Reason: {approval_decision.reason}" ) + if approval_decision.reason and not approval_decision.reason.endswith("."): # type: ignore + denial_string += "." + string += f"Code status on server '{server_identity.server_name}' is '{approval_decision.status}'." if self.denied: return f"{type(self)} Your code cannot be run: {denial_string}" else: return f"{type(self)} Your code is waiting for approval. {string}" - @property - def approved(self) -> bool: - return all(x == UserCodeStatus.APPROVED for x, _ in self.status_dict.values()) - - @property - def denied(self) -> bool: - for status, _ in self.status_dict.values(): - if status == UserCodeStatus.DENIED: - return True - return False - - def for_user_context(self, context: AuthedServiceContext) -> UserCodeStatus: - if context.server.server_type == ServerType.ENCLAVE: - keys = {status for status, _ in self.status_dict.values()} - if len(keys) == 1 and UserCodeStatus.APPROVED in keys: - return UserCodeStatus.APPROVED - elif UserCodeStatus.PENDING in keys and UserCodeStatus.DENIED not in keys: - return UserCodeStatus.PENDING - elif UserCodeStatus.DENIED in keys: - return UserCodeStatus.DENIED - else: - raise Exception(f"Invalid types in {keys} for Code Submission") - - elif context.server.server_type == ServerType.DATASITE: - server_identity = ServerIdentity( - server_name=context.server.name, - server_id=context.server.id, - verify_key=context.server.signing_key.verify_key, - ) - if server_identity in self.status_dict: - return self.status_dict[server_identity][0] - else: - raise Exception( - f"Code Object does not contain {context.server.name} Datasite's data" - ) - else: - raise Exception( - f"Invalid Server Type for Code Submission:{context.server.server_type}" - ) - @as_result(SyftException) def mutate( self, - value: tuple[UserCodeStatus, str], + value: ApprovalDecision, server_name: str, server_id: UID, verify_key: SyftVerifyKey, @@ -264,8 +376,41 @@ def get_sync_dependencies(self, context: AuthedServiceContext) -> list[UID]: return [self.user_code_link.object_uid] +@migrate(UserCodeStatusCollectionV1, UserCodeStatusCollection) +def migrate_user_code_status_to_v2() -> list[Callable]: + def update_statusdict(context: TransformContext) -> TransformContext: + res = {} + if not isinstance(context.obj, UserCodeStatusCollectionV1): + raise Exception("Invalid object type") + if context.output is None: + raise Exception("Output is None") + for server_identity, (status, reason) in context.obj.status_dict.items(): + res[server_identity] = ApprovalDecision(status=status, reason=reason) + context.output["status_dict"] = res + return context + + def set_user_verify_key(context: TransformContext) -> TransformContext: + authed_context = context.to_server_context() + if not isinstance(context.obj, UserCodeStatusCollectionV1): + raise Exception("Invalid object type") + if context.output is None: + raise Exception("Output is None") + user_code = context.obj.user_code_link.resolve_with_context( + authed_context + ).unwrap() + context.output["user_verify_key"] = user_code.user_verify_key + return context + + return [ + make_set_default("was_requested_on_lowside", False), + make_set_default("_has_readable_outputs_cache", None), + update_statusdict, + set_user_verify_key, + ] + + @serializable() -class UserCode(SyncableSyftObject): +class UserCodeV1(SyncableSyftObject): # version __canonical_name__ = "UserCode" __version__ = SYFT_OBJECT_VERSION_1 @@ -335,6 +480,77 @@ class UserCode(SyncableSyftObject): "output_policy_state", ] + +@serializable() +class UserCode(SyncableSyftObject): + # version + __canonical_name__ = "UserCode" + __version__ = SYFT_OBJECT_VERSION_2 + + id: UID + server_uid: UID | None = None + user_verify_key: SyftVerifyKey + raw_code: str + input_policy_type: type[InputPolicy] | UserPolicy + input_policy_init_kwargs: dict[Any, Any] | None = None + input_policy_state: bytes = b"" + output_policy_type: type[OutputPolicy] | UserPolicy + output_policy_init_kwargs: dict[Any, Any] | None = None + output_policy_state: bytes = b"" + parsed_code: str + service_func_name: str + unique_func_name: str + user_unique_func_name: str + code_hash: str + signature: inspect.Signature + status_link: LinkedObject + input_kwargs: list[str] + submit_time: DateTime | None = None + # tracks if the code calls datasite.something, variable is set during parsing + uses_datasite: bool = False + + nested_codes: dict[str, tuple[LinkedObject, dict]] | None = {} + worker_pool_name: str | None = None + origin_server_side_type: ServerSideType + # l0_deny_reason: str | None = None + + __table_coll_widths__ = [ + "min-content", + "auto", + "auto", + "auto", + "auto", + "auto", + "auto", + "auto", + ] + + __attr_searchable__: ClassVar[list[str]] = [ + "user_verify_key", + "service_func_name", + "code_hash", + ] + __attr_unique__: ClassVar[list[str]] = [] + __repr_attrs__: ClassVar[list[str]] = [ + "service_func_name", + "input_owners", + "status", + "worker_pool_name", + # "l0_deny_reason", + "raw_code", + ] + + __exclude_sync_diff_attrs__: ClassVar[list[str]] = [ + "server_uid", + "code_status", + "input_policy_type", + "input_policy_init_kwargs", + "input_policy_state", + "output_policy_type", + "output_policy_init_kwargs", + "output_policy_state", + ] + @field_validator("service_func_name", mode="after") @classmethod def service_func_name_is_valid(cls, value: str) -> str: @@ -356,14 +572,14 @@ def __setattr__(self, key: str, value: Any) -> None: return super().__setattr__(key, value) def _coll_repr_(self) -> dict[str, Any]: - status = [status for status, _ in self.status.status_dict.values()][0].value - if status == UserCodeStatus.PENDING.value: + status = self.status._compute_status() + if status == UserCodeStatus.PENDING: badge_color = "badge-purple" - elif status == UserCodeStatus.APPROVED.value: + elif status == UserCodeStatus.APPROVED: badge_color = "badge-green" else: badge_color = "badge-red" - status_badge = {"value": status, "type": badge_color} + status_badge = {"value": status.value, "type": badge_color} return { "Input Policy": self.input_policy_type.__canonical_name__, "Output Policy": self.output_policy_type.__canonical_name__, @@ -389,86 +605,16 @@ def user(self) -> UserView: api = self.get_api() return api.services.user.get_by_verify_key(self.user_verify_key) - def _compute_status_l0( - self, context: AuthedServiceContext | None = None - ) -> UserCodeStatusCollection: - if context is None: - # Clientside - api = self._get_api() - server_identity = ServerIdentity.from_api(api) - - if self._has_output_read_permissions_cache is None: - is_approved = api.output.has_output_read_permissions( - self.id, self.user_verify_key - ) - self._has_output_read_permissions_cache = is_approved - else: - is_approved = self._has_output_read_permissions_cache - else: - # Serverside - server_identity = ServerIdentity.from_server(context.server) - is_approved = context.server.services.output.has_output_read_permissions( - context, self.id, self.user_verify_key - ) - is_denied = self.l0_deny_reason is not None - - if is_denied: - if is_approved: - prompt_warning_message( - "This request already has results published to the data scientist. " - "They will still be able to access those results." - ) - message = self.l0_deny_reason - status = (UserCodeStatus.DENIED, message) - elif is_approved: - status = (UserCodeStatus.APPROVED, "") - else: - status = (UserCodeStatus.PENDING, "") - status_dict = {server_identity: status} - - return UserCodeStatusCollection( - status_dict=status_dict, - user_code_link=LinkedObject.from_obj(self), - ) - @property def status(self) -> UserCodeStatusCollection: - # Clientside only - - if self.is_l0_deployment: - if self.status_link is not None: - raise SyftException( - public_message="Encountered a low side UserCode object with a status_link." - ) - return self._compute_status_l0() - - if self.status_link is None: - raise SyftException( - public_message="This UserCode does not have a status. Please contact the Admin." - ) - res = self.status_link.resolve - return res - - @as_result(SyftException) - def get_status(self, context: AuthedServiceContext) -> UserCodeStatusCollection: - if self.is_l0_deployment: - if self.status_link is not None: - raise SyftException( - public_message="Encountered a low side UserCode object with a status_link." - ) - return self._compute_status_l0(context) - - if self.status_link is None: - raise SyftException( - public_message="This UserCode does not have a status. Please contact the Admin." - ) - - return self.status_link.resolve_with_context(context).unwrap() + # only use this client side + return self.get_status(None).unwrap() @as_result(SyftException) - def is_status_approved(self, context: AuthedServiceContext) -> bool: - status = self.get_status(context).unwrap() - return status.approved + def get_status( + self, context: AuthedServiceContext | None + ) -> UserCodeStatusCollection: + return self.status_link.resolve_dynamic(context, load_cached=False) @property def input_owners(self) -> list[str] | None: @@ -502,13 +648,8 @@ def output_readers(self) -> list[SyftVerifyKey] | None: return None @property - def code_status(self) -> list: - status_list = [] - for server_view, (status, _) in self.status.status_dict.items(): - status_list.append( - f"Server: {server_view.server_name}, Status: {status.value}", - ) - return status_list + def code_status_str(self) -> str: + return f"Status: {self.status._compute_status().value}" @property def input_policy(self) -> InputPolicy | None: @@ -518,7 +659,7 @@ def input_policy(self) -> InputPolicy | None: def get_input_policy(self, context: AuthedServiceContext) -> InputPolicy | None: status = self.get_status(context).unwrap() - if status.approved or self.input_policy_type.has_safe_serde: + if status.get_is_approved(context) or self.input_policy_type.has_safe_serde: return self._get_input_policy() return None @@ -580,7 +721,7 @@ def input_policy(self, value: Any) -> None: # type: ignore def get_output_policy(self, context: AuthedServiceContext) -> OutputPolicy | None: status = self.get_status(context).unwrap() - if status.approved or self.output_policy_type.has_safe_serde: + if status.get_is_approved(context) or self.output_policy_type.has_safe_serde: return self._get_output_policy() return None @@ -876,7 +1017,7 @@ def _inner_repr(self, level: int = 0) -> str: id: UID = {self.id} service_func_name: str = {self.service_func_name} shareholders: list = {self.input_owners} - status: list = {self.code_status} + status: str = {self.code_status_str} {constants_str} {shared_with_line} inputs: dict = {inputs_str} @@ -931,7 +1072,7 @@ def _ipython_display_(self, level: int = 0) -> None:

{tabs}id: UID = {self.id}

{tabs}service_func_name: str = {self.service_func_name}

{tabs}shareholders: list = {self.input_owners}

-

{tabs}status: list = {self.code_status}

+

{tabs}status: str = {self.code_status_str}

{tabs}{constants_str} {tabs}{shared_with_line}

{tabs}inputs: dict =

{self._inputs_json}

@@ -1533,11 +1674,14 @@ def create_code_status(context: TransformContext) -> TransformContext: if context.output is None: return context - # Low side requests have a computed status - if context.server.server_side_type == ServerSideType.LOW_SIDE: - return context + # # Low side requests have a computed status + # if + # return context + + was_requested_on_lowside = ( + context.server.server_side_type == ServerSideType.LOW_SIDE + ) - input_keys = list(context.output["input_policy_init_kwargs"].keys()) code_link = LinkedObject.from_uid( context.output["id"], UserCode, @@ -1551,15 +1695,23 @@ def create_code_status(context: TransformContext) -> TransformContext: verify_key=context.server.signing_key.verify_key, ) status = UserCodeStatusCollection( - status_dict={server_identity: (UserCodeStatus.PENDING, "")}, + status_dict={ + server_identity: ApprovalDecision(status=UserCodeStatus.PENDING) + }, user_code_link=code_link, + user_verify_key=context.credentials, + was_requested_on_lowside=was_requested_on_lowside, ) elif context.server.server_type == ServerType.ENCLAVE: - status_dict = {key: (UserCodeStatus.PENDING, "") for key in input_keys} + input_keys = list(context.output["input_policy_init_kwargs"].keys()) + status_dict = { + key: ApprovalDecision(status=UserCodeStatus.PENDING) for key in input_keys + } status = UserCodeStatusCollection( status_dict=status_dict, user_code_link=code_link, + user_verify_key=context.credentials, ) else: raise NotImplementedError( @@ -1944,3 +2096,8 @@ def load_approved_policy_code( load_policy_code(user_code.input_policy_type) if isinstance(user_code.output_policy_type, UserPolicy): load_policy_code(user_code.output_policy_type) + + +@migrate(UserCodeV1, UserCode) +def migrate_user_code_to_v2() -> list[Callable]: + return [drop("l0_deny_reason"), drop("_has_output_read_permissions_cache")] diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py index 79490223f45..5ba617ef62e 100644 --- a/packages/syft/src/syft/service/code/user_code_service.py +++ b/packages/syft/src/syft/service/code/user_code_service.py @@ -12,7 +12,6 @@ from ...types.errors import SyftException from ...types.result import Err from ...types.result import as_result -from ...types.syft_metaclass import Empty from ...types.twin_object import TwinObject from ...types.uid import UID from ..action.action_object import ActionObject @@ -148,14 +147,7 @@ def update( context: AuthedServiceContext, code_update: UserCodeUpdate, ) -> SyftSuccess: - code = self.stash.get_by_uid(context.credentials, code_update.id).unwrap() - # FIX: Check if this works (keep commented): - # self.stash.update(context.credentials, code).unwrap() - - if code_update.l0_deny_reason is not Empty: # type: ignore[comparison-overlap] - code.l0_deny_reason = code_update.l0_deny_reason - - updated_code = self.stash.update(context.credentials, code).unwrap() + updated_code = self.stash.update(context.credentials, code_update).unwrap() return SyftSuccess(message="UserCode updated successfully", value=updated_code) @service_method( @@ -353,7 +345,7 @@ def is_execution_allowed( output_policy: OutputPolicy | None, ) -> IsExecutionAllowedEnum: status = code.get_status(context).unwrap() - if not status.approved: + if not status.get_is_approved(context): return IsExecutionAllowedEnum.NOT_APPROVED elif self.has_code_permission(code, context) is HasCodePermissionEnum.DENIED: # TODO: Check enum above @@ -503,8 +495,13 @@ def _call( # code is from low side (L0 setup) status = code.get_status(context).unwrap() - if not status.approved: - raise SyftException(public_message=status.get_status_message()) + if ( + context.server_allows_execution_for_ds + and not status.get_is_approved(context) + ): + raise SyftException( + public_message=status.get_status_message_l2(context) + ) output_policy_is_valid = False try: @@ -640,7 +637,10 @@ def store_execution_output( is_admin = context.role == ServiceRole.ADMIN - if not code.is_status_approved(context) and not is_admin: + if ( + not code.get_status(context).unwrap().get_is_approved(context) + and not is_admin + ): raise SyftException(public_message="This UserCode is not approved") return code.store_execution_output( diff --git a/packages/syft/src/syft/service/context.py b/packages/syft/src/syft/service/context.py index 6e4037719f6..4ce07c67982 100644 --- a/packages/syft/src/syft/service/context.py +++ b/packages/syft/src/syft/service/context.py @@ -59,6 +59,11 @@ def is_l0_lowside(self) -> bool: """Returns True if this is a low side of a Level 0 deployment""" return self.server.server_side_type == ServerSideType.LOW_SIDE + @property + def server_allows_execution_for_ds(self) -> bool: + """Returns True if this is a low side of a Level 0 deployment""" + return not self.is_l0_lowside + def as_root_context(self) -> Self: return AuthedServiceContext( credentials=self.server.verify_key, diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index b1d461e4e80..eef1a113af7 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -246,40 +246,47 @@ def create_migrated_objects( migrated_objects: list[SyftObject], ignore_existing: bool = True, ) -> SyftSuccess: - return self._create_migrated_objects( + self._create_migrated_objects( context, migrated_objects, ignore_existing=ignore_existing ).unwrap() + return SyftSuccess(message="Created migration objects!") @as_result(SyftException) def _create_migrated_objects( self, context: AuthedServiceContext, - migrated_objects: list[SyftObject], + migrated_objects: dict[type[SyftObject], list[SyftObject]], ignore_existing: bool = True, - ) -> SyftSuccess: - for migrated_object in migrated_objects: - stash = self._search_stash_for_klass( - context, type(migrated_object) - ).unwrap() - - result = stash.set( - context.credentials, - obj=migrated_object, - ) - # Exception from the new Error Handling pattern, no need to change - if result.is_err(): - # TODO: subclass a DuplicationKeyError - if ignore_existing and ( - "Duplication Key Error" in result.err()._private_message # type: ignore - or "Duplication Key Error" in result.err().public_message # type: ignore - ): - print( - f"{type(migrated_object)} #{migrated_object.id} already exists" - ) - continue - else: - result.unwrap() # this will raise the exception inside the wrapper - return SyftSuccess(message="Created migrate objects!") + skip_check_type: bool = False, + ) -> dict[type[SyftObject], list[SyftObject]]: + created_objects: dict[type[SyftObject], list[SyftObject]] = {} + for key, objects in migrated_objects.items(): + created_objects[key] = [] + for migrated_object in objects: + stash = self._search_stash_for_klass( + context, type(migrated_object) + ).unwrap() + + result = stash.set( + context.credentials, + obj=migrated_object, + skip_check_type=skip_check_type, + ) + # Exception from the new Error Handling pattern, no need to change + if result.is_err(): + # TODO: subclass a DuplicationKeyError + if ignore_existing and ( + "Duplication Key Error" in result.err()._private_message # type: ignore + or "Duplication Key Error" in result.err().public_message # type: ignore + ): + print( + f"{type(migrated_object)} #{migrated_object.id} already exists" + ) + continue + else: + result.unwrap() # this will raise the exception inside the wrapper + created_objects[key].append(result.unwrap()) + return created_objects @as_result(SyftException) def _update_migrated_objects( @@ -304,6 +311,7 @@ def _migrate_objects( migration_objects: dict[type[SyftObject], list[SyftObject]], ) -> list[SyftObject]: migrated_objects = [] + for klass, objects in migration_objects.items(): canonical_name = klass.__canonical_name__ latest_version = SyftObjectRegistry.get_latest_version(canonical_name) @@ -435,11 +443,16 @@ def apply_migration_data( "please use 'client.load_migration_data' instead." ) + created_objects = self._create_migrated_objects( + context, migration_data.store_objects, skip_check_type=True + ).unwrap() + # migrate + apply store objects migrated_objects = self._migrate_objects( - context, migration_data.store_objects + context, + created_objects, ).unwrap() - self._create_migrated_objects(context, migrated_objects).unwrap() + self._update_migrated_objects(context, migrated_objects).unwrap() # migrate+apply action objects migrated_actionobjects = self._migrate_objects( diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 674c66a019a..4382ce383c7 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -42,6 +42,7 @@ from ..action.action_object import ActionObject from ..action.action_permissions import ActionObjectPermission from ..action.action_permissions import ActionPermission +from ..code.user_code import ApprovalDecision from ..code.user_code import UserCode from ..code.user_code import UserCodeStatus from ..code.user_code import UserCodeStatusCollection @@ -66,8 +67,10 @@ class RequestStatus(Enum): APPROVED = 2 @classmethod - def from_usercode_status(cls, status: UserCodeStatusCollection) -> "RequestStatus": - if status.approved: + def from_usercode_status( + cls, status: UserCodeStatusCollection, context: AuthedServiceContext + ) -> "RequestStatus": + if status.get_is_approved(context): return RequestStatus.APPROVED elif status.denied: return RequestStatus.REJECTED @@ -484,6 +487,15 @@ def code_id(self) -> UID: public_message="This type of request does not have code associated with it." ) + @property + def status_id(self) -> UID: + for change in self.changes: + if isinstance(change, UserCodeStatusChange): + return change.linked_obj.object_uid # type: ignore + raise SyftException( + public_message="This type of request does not have code associated with it." + ) + @property def codes(self) -> Any: for change in self.changes: @@ -532,7 +544,7 @@ def get_status(self, context: AuthedServiceContext | None = None) -> RequestStat code_status = ( self.code.get_status(context) if context else self.code.status ) - return RequestStatus.from_usercode_status(code_status) + return RequestStatus.from_usercode_status(code_status, context) except Exception: # nosec # this breaks when coming from a user submitting a request # which tries to send an email to the admin and ends up here @@ -613,7 +625,11 @@ def deny(self, reason: str) -> SyftSuccess: "This request already has results published to the data scientist. " "They will still be able to access those results." ) - api.code.update(id=self.code_id, l0_deny_reason=reason) + api.code_status.update( + id=self.code.status_link.object_uid, + decision=ApprovalDecision(status=UserCodeStatus.DENIED, reason=reason), + ) + return SyftSuccess(message=f"Request denied with reason: {reason}") return api.services.request.undo(uid=self.id, reason=reason) @@ -1079,9 +1095,7 @@ def accept_by_depositing_result(self, result: Any, force: bool = False) -> Any: pass def get_sync_dependencies(self, context: AuthedServiceContext) -> list[UID]: - dependencies = [] - dependencies.append(self.code_id) - return dependencies + return [self.code_id, self.status_id] @serializable() @@ -1427,7 +1441,9 @@ def mutate( ) -> UserCodeStatusCollection: reason: str = context.extra_kwargs.get("reason", "") return status.mutate( - value=(UserCodeStatus.DENIED if undo else self.value, reason), + value=ApprovalDecision( + status=UserCodeStatus.DENIED if undo else self.value, reason=reason + ), server_name=context.server.name, server_id=context.server.id, verify_key=context.server.signing_key.verify_key, diff --git a/packages/syft/src/syft/service/sync/diff_state.py b/packages/syft/src/syft/service/sync/diff_state.py index d9340895fa7..d61356ed36a 100644 --- a/packages/syft/src/syft/service/sync/diff_state.py +++ b/packages/syft/src/syft/service/sync/diff_state.py @@ -921,15 +921,15 @@ def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str: return "" # Turns off the _repr_markdown_ of SyftObject def _get_visual_hierarchy( - self, server: ObjectDiff, visited: set[UID] | None = None + self, node: ObjectDiff, visited: set[UID] | None = None ) -> dict[ObjectDiff, dict]: visited = visited if visited is not None else set() - visited.add(server.object_id) + visited.add(node.object_id) _, child_types_map = self.visual_hierarchy - child_types = child_types_map.get(server.obj_type, []) - dep_ids = self.dependencies.get(server.object_id, []) + self.dependents.get( - server.object_id, [] + child_types = child_types_map.get(node.obj_type, []) + dep_ids = self.dependencies.get(node.object_id, []) + self.dependents.get( + node.object_id, [] ) result = {} @@ -1444,10 +1444,10 @@ def _create_batches( root_ids.append(diff.object_id) # type: ignore # Dependents are the reverse edges of the dependency graph - obj_dependents = {} + obj_dependents: dict = {} for parent, children in obj_dependencies.items(): for child in children: - obj_dependents[child] = obj_dependencies.get(child, []) + [parent] + obj_dependents[child] = obj_dependents.get(child, []) + [parent] for root_uid in root_ids: batch = ObjectDiffBatch.from_dependencies( diff --git a/packages/syft/src/syft/service/sync/resolve_widget.py b/packages/syft/src/syft/service/sync/resolve_widget.py index d27c106b49f..9b41fc13c70 100644 --- a/packages/syft/src/syft/service/sync/resolve_widget.py +++ b/packages/syft/src/syft/service/sync/resolve_widget.py @@ -505,20 +505,25 @@ def batch_diff_widgets(self) -> list[CollapsableObjectDiffWidget]: return dependent_diff_widgets @property - def dependent_root_diff_widgets(self) -> list[CollapsableObjectDiffWidget]: + def dependency_root_diff_widgets(self) -> list[CollapsableObjectDiffWidget]: dependencies = self.obj_diff_batch.get_dependencies( include_roots=True, include_batch_root=False ) - other_roots = [ - d for d in dependencies if d.object_id in self.obj_diff_batch.global_roots - ] + + # we show these above the line + dependents = self.obj_diff_batch.get_dependents( + include_roots=False, include_batch_root=False + ) + dependent_ids = [x.object_id for x in dependents] + # we skip the ones we already show above the line in the widget + context_diffs = [d for d in dependencies if d.object_id not in dependent_ids] widgets = [ CollapsableObjectDiffWidget( diff, direction=self.obj_diff_batch.sync_direction, build_state=self.build_state, ) - for diff in other_roots + for diff in context_diffs ] return widgets @@ -559,7 +564,7 @@ def build(self) -> VBox: self.id2widget = {} batch_diff_widgets = self.batch_diff_widgets - dependent_batch_diff_widgets = self.dependent_root_diff_widgets + dependent_batch_diff_widgets = self.dependency_root_diff_widgets main_object_diff_widget = self.main_object_diff_widget self.id2widget[main_object_diff_widget.diff.object_id] = main_object_diff_widget diff --git a/packages/syft/src/syft/service/sync/sync_service.py b/packages/syft/src/syft/service/sync/sync_service.py index ddafd86b1d3..4aadaaa079a 100644 --- a/packages/syft/src/syft/service/sync/sync_service.py +++ b/packages/syft/src/syft/service/sync/sync_service.py @@ -100,9 +100,9 @@ def transform_item( if isinstance(item, UserCodeStatusCollection): identity = ServerIdentity.from_server(context.server) res = {} - for key in item.status_dict.keys(): + for approval_decision in item.status_dict.values(): # todo, check if they are actually only two servers - res[identity] = item.status_dict[key] + res[identity] = approval_decision item.status_dict = res self.set_obj_ids(context, item) diff --git a/packages/syft/src/syft/service/worker/worker_image_stash.py b/packages/syft/src/syft/service/worker/worker_image_stash.py index dc220905839..29755e9ef07 100644 --- a/packages/syft/src/syft/service/worker/worker_image_stash.py +++ b/packages/syft/src/syft/service/worker/worker_image_stash.py @@ -33,6 +33,7 @@ def set( add_storage_permission: bool = True, ignore_duplicates: bool = False, session: Session = None, + skip_check_type: bool = False, ) -> SyftWorkerImage: # By default syft images have all read permission add_permissions = [] if add_permissions is None else add_permissions diff --git a/packages/syft/src/syft/service/worker/worker_pool_stash.py b/packages/syft/src/syft/service/worker/worker_pool_stash.py index 81a4f4741d2..3ae0a2d9ec2 100644 --- a/packages/syft/src/syft/service/worker/worker_pool_stash.py +++ b/packages/syft/src/syft/service/worker/worker_pool_stash.py @@ -42,6 +42,7 @@ def set( add_storage_permission: bool = True, ignore_duplicates: bool = False, session: Session = None, + skip_check_type: bool = False, ) -> WorkerPool: # By default all worker pools have all read permission add_permissions = [] if add_permissions is None else add_permissions diff --git a/packages/syft/src/syft/service/worker/worker_stash.py b/packages/syft/src/syft/service/worker/worker_stash.py index 48a192ecd19..d64314a5d81 100644 --- a/packages/syft/src/syft/service/worker/worker_stash.py +++ b/packages/syft/src/syft/service/worker/worker_stash.py @@ -35,6 +35,7 @@ def set( add_storage_permission: bool = True, ignore_duplicates: bool = False, session: Session = None, + skip_check_type: bool = False, ) -> SyftWorker: # By default all worker pools have all read permission add_permissions = [] if add_permissions is None else add_permissions diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index aec2a2ed9c5..85860821866 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -42,6 +42,7 @@ from ...util.telemetry import instrument from ..document_store_errors import NotFoundException from ..document_store_errors import StashException +from ..document_store_errors import UniqueConstraintException from .db import DBManager from .query import Query from .schema import PostgresBase @@ -204,7 +205,8 @@ def is_unique(self, obj: StashT, session: Session = None) -> bool: return False elif len(results) == 1: result = results[0] - return result.id == obj.id + res = result.id == obj.id + return res return True @with_session @@ -360,8 +362,9 @@ def set( add_storage_permission: bool = True, # TODO: check the default value ignore_duplicates: bool = False, session: Session = None, + skip_check_type: bool = False, ) -> StashT: - if not self.allow_any_type: + if not self.allow_any_type and not skip_check_type: self.check_type(obj, self.object_type).unwrap() uid = obj.id @@ -427,7 +430,13 @@ def apply_partial_update( self.object_type.model_validate(original_obj) return original_obj - @as_result(StashException, NotFoundException, AttributeError, ValidationError) + @as_result( + StashException, + NotFoundException, + AttributeError, + ValidationError, + UniqueConstraintException, + ) @with_session def update( self, @@ -454,7 +463,9 @@ def update( # TODO has_permission is not used if not self.is_unique(obj): - raise StashException(f"Some fields are not unique for {type(obj).__name__}") + raise UniqueConstraintException( + f"Some fields are not unique for {type(obj).__name__} and unique fields {self.unique_fields}" + ) stmt = self.table.update().where(self._get_field_filter("id", obj.id)) stmt = self._apply_permission_filter( diff --git a/packages/syft/src/syft/store/document_store_errors.py b/packages/syft/src/syft/store/document_store_errors.py index 69da6b73a8f..04fb6777897 100644 --- a/packages/syft/src/syft/store/document_store_errors.py +++ b/packages/syft/src/syft/store/document_store_errors.py @@ -14,6 +14,10 @@ class StashException(SyftException): public_message = "There was an error retrieving data. Contact your admin." +class UniqueConstraintException(StashException): + public_message = "Another item with the same unique constraint already exists." + + class ObjectCRUDPermissionException(SyftException): public_message = "You do not have permission to perform this action." diff --git a/packages/syft/src/syft/store/linked_obj.py b/packages/syft/src/syft/store/linked_obj.py index d3e40372842..2343dc0b9a6 100644 --- a/packages/syft/src/syft/store/linked_obj.py +++ b/packages/syft/src/syft/store/linked_obj.py @@ -42,7 +42,12 @@ def __str__(self) -> str: @property def resolve(self) -> SyftObject: + return self._resolve() + + def _resolve(self, load_cached: bool = False) -> SyftObject: api = None + if load_cached and self._resolve_cache is not None: + return self._resolve_cache try: # relative api = self.get_api() # raises @@ -53,15 +58,29 @@ def resolve(self) -> SyftObject: logger.error(">>> Failed to resolve object", type(api), e) raise e + def resolve_dynamic( + self, context: ServerServiceContext | None, load_cached: bool = False + ) -> SyftObject: + if context is not None: + return self.resolve_with_context(context, load_cached).unwrap() + else: + return self._resolve(load_cached) + @as_result(SyftException) - def resolve_with_context(self, context: ServerServiceContext) -> Any: + def resolve_with_context( + self, context: ServerServiceContext, load_cached: bool = False + ) -> Any: + if load_cached and self._resolve_cache is not None: + return self._resolve_cache if context.server is None: raise ValueError(f"context {context}'s server is None") - return ( + res = ( context.server.get_service(self.service_type) .resolve_link(context=context, linked_obj=self) .unwrap() ) + self._resolve_cache = res + return res def update_with_context( self, context: ServerServiceContext | ChangeContext | Any, obj: Any diff --git a/packages/syft/src/syft/util/notebook_ui/components/sync.py b/packages/syft/src/syft/util/notebook_ui/components/sync.py index 94e54c60aed..693d9549367 100644 --- a/packages/syft/src/syft/util/notebook_ui/components/sync.py +++ b/packages/syft/src/syft/util/notebook_ui/components/sync.py @@ -13,6 +13,7 @@ from ....service.user.user import UserView from ....types.datetime import DateTime from ....types.datetime import format_timedelta_human_readable +from ....types.errors import SyftException from ....types.syft_object import SYFT_OBJECT_VERSION_1 from ....types.syft_object import SyftObject from ..icons import Icon @@ -97,12 +98,10 @@ def get_status_str(self) -> str: return f"Status: {self.object.status.value}" elif isinstance(self.object, Request): code = self.object.code - statusses = list(code.status.status_dict.values()) - if len(statusses) != 1: + approval_decisions = list(code.status.status_dict.values()) + if len(approval_decisions) != 1: raise ValueError("Request code should have exactly one status") - status_tuple = statusses[0] - status, _ = status_tuple - return status.value + return approval_decisions[0].status.value return "" # type: ignore def get_updated_by(self) -> str: @@ -114,7 +113,10 @@ def get_updated_by(self) -> str: user_view: UserView | None = None if isinstance(self.object, UserCode): - user_view = self.object.user + try: + user_view = self.object.user + except SyftException: + pass # nosec if isinstance(user_view, UserView): return f"Created by {user_view.email}" diff --git a/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py b/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py index 258a698dcd5..998de0f3ad9 100644 --- a/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py +++ b/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py @@ -10,7 +10,10 @@ from syft.client.syncing import compare_clients from syft.client.syncing import resolve from syft.server.worker import Worker +from syft.service.code.user_code import ApprovalDecision +from syft.service.code.user_code import UserCodeStatus from syft.service.job.job_stash import Job +from syft.service.request.request import Request from syft.service.request.request import RequestStatus from syft.service.response import SyftSuccess from syft.service.sync.resolve_widget import ResolveWidget @@ -359,7 +362,6 @@ def compute() -> int: client_low_ds.code.compute(blocking=True) assert "waiting for approval" in exc.value.public_message - assert "PENDING" in exc.value.public_message assert low_client.requests[0].status == RequestStatus.PENDING @@ -381,7 +383,12 @@ def compute() -> int: diff_before, diff_after = compare_and_resolve( from_client=high_client, to_client=low_client, share_private_data=True ) - assert len(diff_before.batches) == 1 and diff_before.batches[0].root_type is Job + assert len(diff_before.batches) == 2 + root_types = [x.root_type for x in diff_before.batches] + assert Job in root_types + assert ( + Request in root_types + ) # we have not configured it to count UserCode as a root type assert low_client.requests[0].status == RequestStatus.APPROVED assert client_low_ds.code.compute().get() == 42 @@ -414,7 +421,10 @@ def compute() -> int: assert low_client.requests[0].status == RequestStatus.REJECTED # Un-deny. NOTE: not supported by current UX, this is just used to re-deny on high side - low_client.api.code.update(id=request_low.code_id, l0_deny_reason=None) + low_client.api.code_status.update( + id=request_low.status_id, + decision=ApprovalDecision(status=UserCodeStatus.PENDING), + ) assert low_client.requests[0].status == RequestStatus.PENDING # Sync request to high side diff --git a/tests/integration/local/twin_api_sync_test.py b/tests/integration/local/twin_api_sync_test.py index 9e8d28adaa7..f8c137e5967 100644 --- a/tests/integration/local/twin_api_sync_test.py +++ b/tests/integration/local/twin_api_sync_test.py @@ -136,8 +136,7 @@ def compute(query): endpoint_path="testapi.query", endpoint_timeout=expected_timeout_after ) widget = sy.sync(from_client=high_client, to_client=low_client) - result = widget[0].click_sync() - assert result, result + widget._sync_all() timeout_after = ( full_low_worker.python_server.services.api.stash.get_all(