diff --git a/packages/syft/src/syft/service/api/api_service.py b/packages/syft/src/syft/service/api/api_service.py index 8df26f0ac15..5f6ba0dbe48 100644 --- a/packages/syft/src/syft/service/api/api_service.py +++ b/packages/syft/src/syft/service/api/api_service.py @@ -72,7 +72,7 @@ def set( public_message="An API endpoint already exists at the given path." ) - result = self.stash.upsert(context.credentials, endpoint=new_endpoint).unwrap() + result = self.stash.upsert(context.credentials, obj=new_endpoint).unwrap() action_obj = ActionObject.from_obj( id=new_endpoint.action_object_id, syft_action_data=CustomEndpointActionObject(endpoint_id=result.id), @@ -157,7 +157,7 @@ def update( endpoint.mock_function.view_access = view_access # save changes - self.stash.upsert(context.credentials, endpoint=endpoint).unwrap() + self.stash.upsert(context.credentials, obj=endpoint).unwrap() return SyftSuccess(message="Endpoint successfully updated.") @service_method( @@ -218,7 +218,7 @@ def set_state( if mock and api_endpoint.mock_function: api_endpoint.mock_function.state = state - self.stash.upsert(context.credentials, endpoint=api_endpoint).unwrap() + self.stash.upsert(context.credentials, obj=api_endpoint).unwrap() return SyftSuccess(message=f"APIEndpoint {api_path} state updated.") @service_method( @@ -248,7 +248,7 @@ def set_settings( if mock and api_endpoint.mock_function: api_endpoint.mock_function.settings = settings - self.stash.upsert(context.credentials, endpoint=api_endpoint).unwrap() + self.stash.upsert(context.credentials, obj=api_endpoint).unwrap() return SyftSuccess(message=f"APIEndpoint {api_path} settings updated.") @service_method( diff --git a/packages/syft/src/syft/service/api/api_stash.py b/packages/syft/src/syft/service/api/api_stash.py index 0c0c6f73020..e892d48da61 100644 --- a/packages/syft/src/syft/service/api/api_stash.py +++ b/packages/syft/src/syft/service/api/api_stash.py @@ -33,22 +33,3 @@ def path_exists(self, credentials: SyftVerifyKey, path: str) -> bool: return True except NotFoundException: return False - - @as_result(StashException) - def upsert( - self, - credentials: SyftVerifyKey, - endpoint: TwinAPIEndpoint, - has_permission: bool = False, - ) -> TwinAPIEndpoint: - """Upsert an endpoint.""" - exists = self.path_exists(credentials=credentials, path=endpoint.path).unwrap() - - if exists: - super().delete_by_uid(credentials=credentials, uid=endpoint.id).unwrap() - - return ( - super() - .set(credentials=credentials, obj=endpoint, ignore_duplicates=False) - .unwrap() - ) diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 85860821866..5323f3455c8 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -75,6 +75,8 @@ def with_session(func: Callable[P, T]) -> Callable[P, T]: # type: ignore """ Decorator to inject a session into the function kwargs if it is not provided. + Make sure to pass session as a keyword argument to the function. + TODO: This decorator is a temporary fix, we want to move to a DI approach instead: move db connection and session to context, and pass context to all stash methods. """ @@ -87,8 +89,9 @@ def with_session(func: Callable[P, T]) -> Callable[P, T]: # type: ignore def wrapper(self: "ObjectStash[StashT]", *args: Any, **kwargs: Any) -> Any: if inject_session and kwargs.get("session") is None: with self.sessionmaker() as session: - kwargs["session"] = session - return func(self, *args, **kwargs) + with session.begin(): + kwargs["session"] = session + return func(self, *args, **kwargs) return func(self, *args, **kwargs) return wrapper # type: ignore @@ -369,11 +372,13 @@ def set( uid = obj.id # check if the object already exists - if self.exists(credentials, uid) or not self.is_unique(obj): + if self.exists(credentials, uid, session=session) or not self.is_unique( + obj, session=session + ): if ignore_duplicates: return obj unique_fields_str = ", ".join(self.unique_fields) - raise StashException( + raise UniqueConstraintException( public_message=f"Duplication Key Error for {obj}.\n" f"The fields that should be unique are {unique_fields_str}." ) @@ -399,7 +404,6 @@ def set( raise StashException( f"Error serializing object: {e}. Some fields are invalid." ) - # create the object with the permissions stmt = self.table.insert().values( id=uid, @@ -408,7 +412,6 @@ def set( storage_permissions=storage_permissions, ) session.execute(stmt) - session.commit() return self.get_by_uid(credentials, uid, session=session).unwrap() @as_result(ValidationError, AttributeError) @@ -462,7 +465,7 @@ def update( ).unwrap() # TODO has_permission is not used - if not self.is_unique(obj): + if not self.is_unique(obj, session=session): raise UniqueConstraintException( f"Some fields are not unique for {type(obj).__name__} and unique fields {self.unique_fields}" ) @@ -483,14 +486,12 @@ def update( f"Error serializing object: {e}. Some fields are invalid." ) stmt = stmt.values(fields=fields) - result = session.execute(stmt) - session.commit() if result.rowcount == 0: raise NotFoundException( f"{self.object_type.__name__}: {obj.id} not found or no permission to update." ) - return self.get_by_uid(credentials, obj.id).unwrap() + return self.get_by_uid(credentials, obj.id, session=session).unwrap() @as_result(StashException, NotFoundException) @with_session @@ -510,7 +511,6 @@ def delete_by_uid( session=session, ) result = session.execute(stmt) - session.commit() if result.rowcount == 0: raise NotFoundException( f"{self.object_type.__name__}: {uid} not found or no permission to delete." @@ -649,8 +649,6 @@ def add_permission( stmt = self.table.update().where(self.table.c.id == permission.uid) stmt = stmt.values(permissions=list(existing_permissions)) session.execute(stmt) - session.commit() - return None @as_result(NotFoundException) @@ -685,7 +683,6 @@ def remove_permission( .values(permissions=list(permissions)) ) session.execute(stmt) - session.commit() return None @with_session @@ -842,7 +839,6 @@ def remove_storage_permission( .values(storage_permissions=[str(uid) for uid in permissions]) ) session.execute(stmt) - session.commit() return None @as_result(StashException) @@ -857,3 +853,26 @@ def _get_storage_permissions_for_uid( if result is None: raise NotFoundException(f"No storage permissions found for uid: {uid}") return {UID(uid) for uid in result.storage_permissions} + + @with_session + @as_result(StashException) + def upsert( + self, + credentials: SyftVerifyKey, + obj: StashT, + session: Session = None, + ) -> StashT: + """Insert or update an object in the stash if it already exists. + Atomic operation when using the same session for both operations. + """ + + try: + return self.set( + credentials=credentials, + obj=obj, + session=session, + ).unwrap() + except UniqueConstraintException: + return self.update( + credentials=credentials, obj=obj, session=session + ).unwrap() diff --git a/packages/syft/tests/syft/stores/base_stash_test.py b/packages/syft/tests/syft/stores/base_stash_test.py index 8ed9a312b86..5344ec8f5ab 100644 --- a/packages/syft/tests/syft/stores/base_stash_test.py +++ b/packages/syft/tests/syft/stores/base_stash_test.py @@ -190,6 +190,30 @@ def test_basestash_update( assert retrieved == updated_obj +def test_basestash_upsert( + root_verify_key, base_stash: MockStash, mock_object: MockObject, faker: Faker +) -> None: + base_stash.set(root_verify_key, mock_object).unwrap() + + updated_obj = mock_object.copy() + updated_obj.name = faker.name() + + retrieved = base_stash.upsert(root_verify_key, updated_obj).unwrap() + assert retrieved == updated_obj + + updated_obj.id = UID() + + with pytest.raises(StashException): + # fails because the name should be unique + base_stash.upsert(root_verify_key, updated_obj).unwrap() + + updated_obj.name = faker.name() + + retrieved = base_stash.upsert(root_verify_key, updated_obj).unwrap() + assert retrieved == updated_obj + assert len(base_stash.get_all(root_verify_key).unwrap()) == 2 + + def test_basestash_cannot_update_non_existent( root_verify_key, base_stash: MockStash, mock_object: MockObject, faker: Faker ) -> None: