diff --git a/app/db/auth.py b/app/db/auth.py index 5e69795ea..ab2e7343a 100644 --- a/app/db/auth.py +++ b/app/db/auth.py @@ -7,24 +7,46 @@ from sqlalchemy.orm import Query from app.db.model import Entity +from app.db.utils import get_declaring_class from app.schemas.auth import UserContext def constrain_to_accessible_entities[Q: Query | Select]( query: Q, - project_id: UUID4 | None, + user_context: UserContext | None, db_model_class: Any = Entity, ) -> Q: """Ensure a query is filtered to rows that are viewable by the user.""" - query = query.where( + if not user_context: # admin or global resource + return query + + # if model or alias has an authorized_project_id use it as is + if hasattr(db_model_class, "authorized_project_id"): + id_model_class = db_model_class + # otherwise look up the hierarchy to check if there is one defined there + else: + id_model_class = get_declaring_class(db_model_class, "authorized_project_id") + # global resource without authorized_project_id, always accessible + if not id_model_class: + return query + + # if user passes a specific project_id, use it to constrain resources + if user_context.project_id: + return query.where( + or_( + id_model_class.authorized_public == true(), + id_model_class.authorized_project_id == user_context.project_id, + ) + ) + + # otherwise use user_project_ids from token to check if user has access + return query.where( or_( - db_model_class.authorized_public == true(), - db_model_class.authorized_project_id == project_id if project_id else false(), + id_model_class.authorized_public == true(), + id_model_class.authorized_project_id.in_(user_context.user_project_ids), ) ) - return query - def constrain_to_private_entities[Q: Query | Select]( query: Q, @@ -32,12 +54,20 @@ def constrain_to_private_entities[Q: Query | Select]( db_model_class: Any = Entity, ) -> Q: """Ensure a query is filtered to private rows that are viewable by the user.""" + # if user passes a specific project_id, use it to constrain resources + if user_context.project_id: + return query.where( + and_( + db_model_class.authorized_public == false(), + db_model_class.authorized_project_id == user_context.project_id, + ) + ) + + # otherwise use project_ids from token to check if user has access return query.where( and_( db_model_class.authorized_public == false(), - db_model_class.authorized_project_id.in_(user_context.user_project_ids) - if user_context.user_project_ids - else false(), + db_model_class.authorized_project_id.in_(user_context.user_project_ids), ) ) diff --git a/app/queries/common.py b/app/queries/common.py index c6318c55c..2c9a6360b 100644 --- a/app/queries/common.py +++ b/app/queries/common.py @@ -39,7 +39,7 @@ def router_read_one[T: BaseModel, I: Identifiable]( id_: uuid.UUID, db: Session, db_model_class: type[I], - authorized_project_id: uuid.UUID | None, + user_context: UserContext | None, response_schema_class: type[T], apply_operations: ApplyOperations[I] | None, ) -> T: @@ -49,7 +49,7 @@ def router_read_one[T: BaseModel, I: Identifiable]( id_: id of the entity to read. db: database session. db_model_class: database model class. - authorized_project_id: id of the authorized project. + user_context: the user's context response_schema_class: Pydantic schema class for the returned data. apply_operations: transformer function that modifies the select query. @@ -57,12 +57,7 @@ def router_read_one[T: BaseModel, I: Identifiable]( the model data as a Pydantic model. """ query = sa.select(db_model_class).where(db_model_class.id == id_) - if authorized_project_id and ( - id_model_class := get_declaring_class(db_model_class, "authorized_project_id") - ): - query = constrain_to_accessible_entities( - query, authorized_project_id, db_model_class=id_model_class - ) + query = constrain_to_accessible_entities(query, user_context, db_model_class) if apply_operations: query = apply_operations(query) with ensure_result(error_message=f"{db_model_class.__name__} not found"): @@ -235,7 +230,7 @@ def router_read_many[T: BaseModel, I: Identifiable]( # noqa: PLR0913 *, db: Session, db_model_class: type[I], - authorized_project_id: uuid.UUID | None, + user_context: UserContext | None, with_search: Search[I] | None, with_in_brain_region: InBrainRegionQuery | None, facets: WithFacets | None, @@ -254,7 +249,7 @@ def router_read_many[T: BaseModel, I: Identifiable]( # noqa: PLR0913 Args: db: database session. db_model_class: database model class. - authorized_project_id: project id for filtering the resources. + user_context: the user's context with_search: search query (str). with_in_brain_region: enable family queries based on BrainRegion facets: facet query (bool). @@ -274,12 +269,7 @@ def router_read_many[T: BaseModel, I: Identifiable]( # noqa: PLR0913 the list of model data, pagination, and facets as a Pydantic model. """ filter_query = sa.select(db_model_class) - if id_model_class := get_declaring_class(db_model_class, "authorized_project_id"): - filter_query = constrain_to_accessible_entities( - filter_query, - project_id=authorized_project_id, - db_model_class=id_model_class, - ) + filter_query = constrain_to_accessible_entities(filter_query, user_context, db_model_class) if apply_filter_query_operations: filter_query = apply_filter_query_operations(filter_query) @@ -387,7 +377,7 @@ def router_delete_one[T: BaseModel, I: Identifiable]( id_: uuid.UUID, db: Session, db_model_class: type[I], - authorized_project_id: uuid.UUID | None, + user_context: UserContext | None, ) -> dict: """Delete a model from the database. @@ -395,15 +385,10 @@ def router_delete_one[T: BaseModel, I: Identifiable]( id_: id of the entity to read. db: database session. db_model_class: database model class. - authorized_project_id: project id for filtering the resources. + user_context: the user's context """ query = sa.select(db_model_class).where(db_model_class.id == id_) - if authorized_project_id and ( - id_model_class := get_declaring_class(db_model_class, "authorized_project_id") - ): - query = constrain_to_accessible_entities( - query, authorized_project_id, db_model_class=id_model_class - ) + query = constrain_to_accessible_entities(query, user_context, db_model_class) with ensure_result(error_message=f"{db_model_class.__name__} not found"): obj = db.execute(query).scalars().one() @@ -427,7 +412,7 @@ def router_update_activity_one[T: BaseModel, I: Activity]( id_: uuid.UUID, db: Session, db_model_class: type[I], - user_context: UserContext | UserContextWithProjectId, + user_context: UserContext, json_model: ActivityUpdate, response_schema_class: type[T], apply_operations: ApplyOperations | None = None, @@ -435,7 +420,7 @@ def router_update_activity_one[T: BaseModel, I: Activity]( query = sa.select(db_model_class).where(db_model_class.id == id_) if id_model_class := get_declaring_class(db_model_class, "authorized_project_id"): query = constrain_to_accessible_entities( - query, user_context.project_id, db_model_class=id_model_class + query, user_context=user_context, db_model_class=id_model_class ) if apply_operations: query = apply_operations(query) diff --git a/app/queries/entity.py b/app/queries/entity.py index b46cc04bd..da1b8f1d9 100644 --- a/app/queries/entity.py +++ b/app/queries/entity.py @@ -6,13 +6,14 @@ from app.db.auth import constrain_entity_query_to_project, constrain_to_accessible_entities from app.db.model import Entity from app.errors import ensure_result +from app.schemas.auth import UserContext def get_readable_entity[T: Entity]( db: Session, db_model_class: type[T], entity_id: uuid.UUID, - project_id: uuid.UUID | None, + user_context: UserContext | None, ) -> T: """Return a specific entity by type and id, readable by the given project. @@ -20,14 +21,14 @@ def get_readable_entity[T: Entity]( db: db session. db_model_class: Entity subclass. entity_id: id of the entity. - project_id: optional project id owning the entity. + user_context: optional user context Returns: the selected entity if it's public or owned by project_id, or raises NoResultFound if the entity doesn't exist, or it's forbidden. """ query = sa.select(db_model_class).where(db_model_class.id == entity_id) - query = constrain_to_accessible_entities(query, project_id=project_id) + query = constrain_to_accessible_entities(query, user_context=user_context) with ensure_result(f"Entity {db_model_class.__name__} {entity_id} not found or forbidden"): return db.execute(query).scalar_one() diff --git a/app/routers/admin.py b/app/routers/admin.py index 9102668a0..05756d87e 100644 --- a/app/routers/admin.py +++ b/app/routers/admin.py @@ -32,7 +32,7 @@ def delete_one( id_=id_, db=db, db_model_class=RESOURCE_TYPE_TO_CLASS[resource_type], - authorized_project_id=None, + user_context=None, ) diff --git a/app/routers/ion_channel_recording.py b/app/routers/ion_channel_recording.py index 80fc79d0c..4b2acb9a7 100644 --- a/app/routers/ion_channel_recording.py +++ b/app/routers/ion_channel_recording.py @@ -1,13 +1,20 @@ from fastapi import APIRouter import app.service.ion_channel_recording +from app.routers.admin import router as admin_router + +ROUTE = "ion-channel-recording" router = APIRouter( - prefix="/ion-channel-recording", - tags=["ion-channel-recording"], + prefix=f"/{ROUTE}", + tags=[ROUTE], ) read_many = router.get("")(app.service.ion_channel_recording.read_many) read_one = router.get("/{id_}")(app.service.ion_channel_recording.read_one) create_one = router.post("")(app.service.ion_channel_recording.create_one) update_one = router.patch("/{id_}")(app.service.ion_channel_recording.update_one) + +admin_read_one = admin_router.get(f"/{ROUTE}/{{id_}}")( + app.service.ion_channel_recording.admin_read_one +) diff --git a/app/service/admin.py b/app/service/admin.py index ebfade1d1..0394af419 100644 --- a/app/service/admin.py +++ b/app/service/admin.py @@ -49,7 +49,7 @@ def get_entity_assets( return router_read_many( db=repos.db, db_model_class=db_model_class, - authorized_project_id=None, + user_context=None, with_search=None, with_in_brain_region=None, facets=None, diff --git a/app/service/asset.py b/app/service/asset.py index 3a8b4ff26..ebdf9bf8f 100644 --- a/app/service/asset.py +++ b/app/service/asset.py @@ -61,7 +61,7 @@ def get_entity_assets( return router_read_many( db=repos.db, db_model_class=db_model_class, - authorized_project_id=user_context.project_id, + user_context=user_context, with_search=None, with_in_brain_region=None, facets=None, diff --git a/app/service/brain_atlas.py b/app/service/brain_atlas.py index d7bd709dc..a5ad4c46d 100644 --- a/app/service/brain_atlas.py +++ b/app/service/brain_atlas.py @@ -33,7 +33,7 @@ def read_many( return app.queries.common.router_read_many( db=db, db_model_class=BrainAtlas, - authorized_project_id=user_context.project_id, + user_context=user_context, with_search=None, with_in_brain_region=None, facets=None, @@ -52,7 +52,7 @@ def read_one(user_context: UserContextDep, atlas_id: uuid.UUID, db: SessionDep) id_=atlas_id, db=db, db_model_class=BrainAtlas, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=BrainAtlasRead, apply_operations=_load_brain_atlas, ) @@ -63,7 +63,7 @@ def admin_read_one(db: SessionDep, atlas_id: uuid.UUID) -> BrainAtlasRead: id_=atlas_id, db=db, db_model_class=BrainAtlas, - authorized_project_id=None, + user_context=None, response_schema_class=BrainAtlasRead, apply_operations=_load_brain_atlas, ) @@ -79,7 +79,7 @@ def read_many_region( return app.queries.common.router_read_many( db=db, db_model_class=BrainAtlasRegion, - authorized_project_id=user_context.project_id, + user_context=user_context, with_search=None, with_in_brain_region=None, facets=None, @@ -102,7 +102,7 @@ def read_one_region( id_=atlas_region_id, db=db, db_model_class=BrainAtlasRegion, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=BrainAtlasRegionRead, apply_operations=lambda select: select.filter( BrainAtlasRegion.brain_atlas_id == atlas_id diff --git a/app/service/brain_region.py b/app/service/brain_region.py index c14cee032..93cba50f9 100644 --- a/app/service/brain_region.py +++ b/app/service/brain_region.py @@ -28,7 +28,7 @@ def read_many( return app.queries.common.router_read_many( db=db, db_model_class=BrainRegion, - authorized_project_id=None, + user_context=None, with_search=None, with_in_brain_region=None, facets=None, diff --git a/app/service/brain_region_hierarchy.py b/app/service/brain_region_hierarchy.py index 064f9ec1d..242fe484a 100644 --- a/app/service/brain_region_hierarchy.py +++ b/app/service/brain_region_hierarchy.py @@ -31,7 +31,7 @@ def read_many( return app.queries.common.router_read_many( db=db, db_model_class=BrainRegionHierarchy, - authorized_project_id=None, + user_context=None, with_search=None, with_in_brain_region=None, facets=None, @@ -50,7 +50,7 @@ def read_one(id_: uuid.UUID, db: SessionDep) -> BrainRegionHierarchyRead: id_=id_, db=db, db_model_class=BrainRegionHierarchy, - authorized_project_id=None, + user_context=None, response_schema_class=BrainRegionHierarchyRead, apply_operations=_load, ) diff --git a/app/service/calibration.py b/app/service/calibration.py index a2271867f..a45485a76 100644 --- a/app/service/calibration.py +++ b/app/service/calibration.py @@ -52,7 +52,7 @@ def read_one( db=db, id_=id_, db_model_class=Calibration, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=CalibrationRead, apply_operations=_load, ) @@ -66,7 +66,7 @@ def admin_read_one( db=db, id_=id_, db_model_class=Calibration, - authorized_project_id=None, + user_context=None, response_schema_class=CalibrationRead, apply_operations=_load, ) @@ -137,13 +137,13 @@ def read_many( aliases=aliases, pagination_request=pagination_request, response_schema_class=CalibrationRead, - authorized_project_id=user_context.project_id, + user_context=user_context, filter_joins=filter_joins, ) def delete_one( - user_context: UserContextWithProjectIdDep, + user_context: UserContextDep, db: SessionDep, id_: uuid.UUID, ) -> CalibrationRead: @@ -151,7 +151,7 @@ def delete_one( id_=id_, db=db, db_model_class=Calibration, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=CalibrationRead, apply_operations=_load, ) @@ -159,7 +159,7 @@ def delete_one( id_=id_, db=db, db_model_class=Calibration, - authorized_project_id=None, # already validated + user_context=None, # already validated ) return one @@ -168,7 +168,7 @@ def update_one( db: SessionDep, id_: uuid.UUID, json_model: CalibrationUpdate, # pyright: ignore [reportInvalidTypeForm] - user_context: UserContextWithProjectIdDep, + user_context: UserContextDep, ) -> CalibrationRead: return router_update_activity_one( db=db, diff --git a/app/service/cell_composition.py b/app/service/cell_composition.py index 4c9a16c8c..5e8dca1af 100644 --- a/app/service/cell_composition.py +++ b/app/service/cell_composition.py @@ -44,7 +44,7 @@ def read_one( db=db, id_=id_, db_model_class=CellComposition, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=CellCompositionRead, apply_operations=_load, ) @@ -58,7 +58,7 @@ def admin_read_one( db=db, id_=id_, db_model_class=CellComposition, - authorized_project_id=None, + user_context=None, response_schema_class=CellCompositionRead, apply_operations=_load, ) @@ -84,6 +84,6 @@ def read_many( aliases={}, pagination_request=pagination_request, response_schema_class=CellCompositionRead, - authorized_project_id=user_context.project_id, + user_context=user_context, filter_joins=None, ) diff --git a/app/service/circuit.py b/app/service/circuit.py index 5467bc3cf..5a8438e45 100644 --- a/app/service/circuit.py +++ b/app/service/circuit.py @@ -61,7 +61,7 @@ def read_one( db=db, id_=id_, db_model_class=Circuit, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=CircuitRead, apply_operations=_load, ) @@ -75,7 +75,7 @@ def admin_read_one( db=db, id_=id_, db_model_class=Circuit, - authorized_project_id=None, + user_context=None, response_schema_class=CircuitRead, apply_operations=_load, ) @@ -158,6 +158,6 @@ def read_many( aliases=aliases, pagination_request=pagination_request, response_schema_class=CircuitRead, - authorized_project_id=user_context.project_id, + user_context=user_context, filter_joins=filter_joins, ) diff --git a/app/service/consortium.py b/app/service/consortium.py index 635849a29..6e5520982 100644 --- a/app/service/consortium.py +++ b/app/service/consortium.py @@ -52,7 +52,7 @@ def read_many( return app.queries.common.router_read_many( db=db, db_model_class=Consortium, - authorized_project_id=None, + user_context=None, with_search=None, with_in_brain_region=None, facets=None, @@ -72,21 +72,14 @@ def read_one(id_: uuid.UUID, db: SessionDep) -> ConsortiumRead: id_=id_, db=db, db_model_class=Consortium, - authorized_project_id=None, + user_context=None, response_schema_class=ConsortiumRead, apply_operations=_load, ) -def admin_read_one(db: SessionDep, id_: uuid.UUID) -> ConsortiumRead: - return app.queries.common.router_read_one( - id_=id_, - db=db, - db_model_class=Consortium, - authorized_project_id=None, - response_schema_class=ConsortiumRead, - apply_operations=_load, - ) +# global resource +admin_read_one = read_one def create_one( diff --git a/app/service/contribution.py b/app/service/contribution.py index ae141aafd..42f66b6a8 100644 --- a/app/service/contribution.py +++ b/app/service/contribution.py @@ -61,7 +61,7 @@ def read_many( aliases=aliases, ) - filter_query = lambda q: constrain_to_accessible_entities(_load(q), user_context.project_id) + filter_query = lambda q: constrain_to_accessible_entities(_load(q), user_context) return app.queries.common.router_read_many( db=db, @@ -76,7 +76,7 @@ def read_many( response_schema_class=ContributionRead, name_to_facet_query_params=name_to_facet_query_params, filter_model=filter_model, - authorized_project_id=user_context.project_id, + user_context=user_context, filter_joins=filter_joins, ) @@ -90,10 +90,10 @@ def read_one( id_=id_, db=db, db_model_class=Contribution, - authorized_project_id=None, + user_context=None, response_schema_class=ContributionRead, apply_operations=lambda q: constrain_to_accessible_entities( - _load(q), user_context.project_id + _load(q), user_context=user_context ), ) @@ -106,7 +106,7 @@ def admin_read_one( id_=id_, db=db, db_model_class=Contribution, - authorized_project_id=None, + user_context=None, response_schema_class=ContributionRead, apply_operations=_load, ) diff --git a/app/service/derivation.py b/app/service/derivation.py index 341030acc..db0997e23 100644 --- a/app/service/derivation.py +++ b/app/service/derivation.py @@ -58,7 +58,7 @@ def apply_filter_query_operations(q): return router_read_many( db=db, db_model_class=db_model_class, - authorized_project_id=user_context.project_id, + user_context=user_context, with_search=None, with_in_brain_region=None, facets=None, diff --git a/app/service/electrical_cell_recording.py b/app/service/electrical_cell_recording.py index 9a87a9143..153e7c8ed 100644 --- a/app/service/electrical_cell_recording.py +++ b/app/service/electrical_cell_recording.py @@ -67,7 +67,7 @@ def read_one( db=db, id_=id_, db_model_class=ElectricalCellRecording, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=ElectricalCellRecordingRead, apply_operations=_load, ) @@ -81,7 +81,7 @@ def admin_read_one( db=db, id_=id_, db_model_class=ElectricalCellRecording, - authorized_project_id=None, + user_context=None, response_schema_class=ElectricalCellRecordingRead, apply_operations=_load, ) @@ -162,7 +162,7 @@ def read_many( aliases=aliases, pagination_request=pagination_request, response_schema_class=ElectricalCellRecordingRead, - authorized_project_id=user_context.project_id, + user_context=user_context, filter_joins=filter_joins, ) diff --git a/app/service/electrical_recording_stimulus.py b/app/service/electrical_recording_stimulus.py index ddc13537a..baba13d1c 100644 --- a/app/service/electrical_recording_stimulus.py +++ b/app/service/electrical_recording_stimulus.py @@ -52,7 +52,7 @@ def read_one( db=db, id_=id_, db_model_class=ElectricalRecordingStimulus, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=ElectricalRecordingStimulusRead, apply_operations=_load, ) @@ -66,7 +66,7 @@ def admin_read_one( db=db, id_=id_, db_model_class=ElectricalRecordingStimulus, - authorized_project_id=None, + user_context=None, response_schema_class=ElectricalRecordingStimulusRead, apply_operations=_load, ) @@ -149,6 +149,6 @@ def read_many( aliases=aliases, pagination_request=pagination_request, response_schema_class=ElectricalRecordingStimulusRead, - authorized_project_id=user_context.project_id, + user_context=user_context, filter_joins=filter_joins, ) diff --git a/app/service/em_cell_mesh.py b/app/service/em_cell_mesh.py index a8b6c0ca8..47cf44f95 100644 --- a/app/service/em_cell_mesh.py +++ b/app/service/em_cell_mesh.py @@ -83,7 +83,7 @@ def read_many( return router_read_many( db=db, db_model_class=EMCellMesh, - authorized_project_id=user_context.project_id, + user_context=user_context, with_search=with_search, with_in_brain_region=in_brain_region, facets=facets, @@ -107,7 +107,7 @@ def read_one( id_=id_, db=db, db_model_class=EMCellMesh, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=EMCellMeshRead, apply_operations=_load, ) diff --git a/app/service/em_dense_reconstruction_dataset.py b/app/service/em_dense_reconstruction_dataset.py index 849454c79..a339260c2 100644 --- a/app/service/em_dense_reconstruction_dataset.py +++ b/app/service/em_dense_reconstruction_dataset.py @@ -74,7 +74,7 @@ def read_many( return router_read_many( db=db, db_model_class=EMDenseReconstructionDataset, - authorized_project_id=user_context.project_id, + user_context=user_context, with_search=with_search, with_in_brain_region=in_brain_region, facets=facets, @@ -98,7 +98,7 @@ def read_one( id_=id_, db=db, db_model_class=EMDenseReconstructionDataset, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=EMDenseReconstructionDatasetRead, apply_operations=_load, ) diff --git a/app/service/emodel.py b/app/service/emodel.py index 7f2c9af35..cceea66e4 100644 --- a/app/service/emodel.py +++ b/app/service/emodel.py @@ -69,7 +69,7 @@ def read_one( id_=id_, db=db, db_model_class=EModel, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=EModelReadExpanded, apply_operations=_load, ) @@ -83,7 +83,7 @@ def admin_read_one( db=db, id_=id_, db_model_class=EModel, - authorized_project_id=None, + user_context=None, response_schema_class=EModelReadExpanded, apply_operations=_load, ) @@ -162,7 +162,7 @@ def read_many( return router_read_many( db=db, db_model_class=EModel, - authorized_project_id=user_context.project_id, + user_context=user_context, with_search=with_search, with_in_brain_region=in_brain_region, facets=facets, diff --git a/app/service/entity.py b/app/service/entity.py index e261953ab..576bde617 100644 --- a/app/service/entity.py +++ b/app/service/entity.py @@ -31,7 +31,7 @@ def get_readable_entity( db=repos.db, db_model_class=db_model_class, entity_id=entity_id, - project_id=user_context.project_id, + user_context=user_context, ) @@ -88,7 +88,7 @@ def count_entities_by_type( sa.func.count(entity_class.id).label("count"), ) q = constrain_to_accessible_entities( - q, project_id=user_context.project_id, db_model_class=entity_class + q, user_context=user_context, db_model_class=entity_class ) q = q.join(brain_region_cte, entity_class.brain_region_id == brain_region_cte.c.id) # type: ignore[reportAttributeAccessIssue] @@ -104,7 +104,7 @@ def count_entities_by_type( Entity.type.label("type"), sa.func.count(Entity.id).label("count") ).select_from(Entity) query = constrain_to_accessible_entities( - query, project_id=user_context.project_id, db_model_class=Entity + query, user_context=user_context, db_model_class=Entity ) query = query.where(Entity.type.in_(entity_types)) query = query.group_by(Entity.type) diff --git a/app/service/etype.py b/app/service/etype.py index 14e4b1334..e9d3e8044 100644 --- a/app/service/etype.py +++ b/app/service/etype.py @@ -17,7 +17,7 @@ def read_many( return router_read_many( db=db, db_model_class=ETypeClass, - authorized_project_id=None, + user_context=None, with_search=None, with_in_brain_region=None, facets=None, @@ -36,7 +36,7 @@ def read_one(id_: uuid.UUID, db: SessionDep) -> ETypeClassRead: id_=id_, db=db, db_model_class=ETypeClass, - authorized_project_id=None, + user_context=None, response_schema_class=ETypeClassRead, apply_operations=None, ) diff --git a/app/service/etype_classification.py b/app/service/etype_classification.py index 474a00e9d..0615d7729 100644 --- a/app/service/etype_classification.py +++ b/app/service/etype_classification.py @@ -5,7 +5,7 @@ from fastapi import HTTPException from sqlalchemy.orm import aliased, joinedload, raiseload -from app.db.auth import constrain_to_accessible_entities +from app.db.auth import constrain_entity_query_to_project from app.db.model import ( Agent, Entity, @@ -45,9 +45,9 @@ def create_one( json_model: ETypeClassificationCreate, user_context: UserContextWithProjectIdDep, ) -> ETypeClassificationRead: - stmt = constrain_to_accessible_entities( + stmt = constrain_entity_query_to_project( sa.select(sa.func.count(Entity.id)).where(Entity.id == json_model.entity_id), - user_context.project_id, + project_id=user_context.project_id, ) if db.execute(stmt).scalar_one() == 0: L.warning("Attempting to create an annotation for an entity inaccessible to user") @@ -78,7 +78,7 @@ def read_one( db=db, id_=id_, db_model_class=ETypeClassification, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=ETypeClassificationRead, apply_operations=_load, ) @@ -92,7 +92,7 @@ def admin_read_one( db=db, id_=id_, db_model_class=ETypeClassification, - authorized_project_id=None, + user_context=None, response_schema_class=ETypeClassificationRead, apply_operations=_load, ) @@ -140,7 +140,7 @@ def read_many( aliases=aliases, pagination_request=pagination_request, response_schema_class=ETypeClassificationRead, - authorized_project_id=user_context.project_id, + user_context=user_context, filter_joins=filter_joins, with_in_brain_region=None, ) diff --git a/app/service/experimental_bouton_density.py b/app/service/experimental_bouton_density.py index 118d02ea4..8de36d44a 100644 --- a/app/service/experimental_bouton_density.py +++ b/app/service/experimental_bouton_density.py @@ -114,7 +114,7 @@ def read_many( aliases=aliases, pagination_request=pagination_request, response_schema_class=ExperimentalBoutonDensityRead, - authorized_project_id=user_context.project_id, + user_context=user_context, filter_joins=filter_joins, ) @@ -128,7 +128,7 @@ def read_one( db=db, id_=id_, db_model_class=ExperimentalBoutonDensity, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=ExperimentalBoutonDensityRead, apply_operations=_load, ) @@ -142,7 +142,7 @@ def admin_read_one( db=db, id_=id_, db_model_class=ExperimentalBoutonDensity, - authorized_project_id=None, + user_context=None, response_schema_class=ExperimentalBoutonDensityRead, apply_operations=_load, ) diff --git a/app/service/experimental_neuron_density.py b/app/service/experimental_neuron_density.py index 5474cf7b4..abd533f9a 100644 --- a/app/service/experimental_neuron_density.py +++ b/app/service/experimental_neuron_density.py @@ -117,7 +117,7 @@ def read_many( aliases=aliases, pagination_request=pagination_request, response_schema_class=ExperimentalNeuronDensityRead, - authorized_project_id=user_context.project_id, + user_context=user_context, filter_joins=filter_joins, ) @@ -131,7 +131,7 @@ def read_one( db=db, id_=id_, db_model_class=ExperimentalNeuronDensity, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=ExperimentalNeuronDensityRead, apply_operations=_load, ) @@ -145,7 +145,7 @@ def admin_read_one( db=db, id_=id_, db_model_class=ExperimentalNeuronDensity, - authorized_project_id=None, + user_context=None, response_schema_class=ExperimentalNeuronDensityRead, apply_operations=_load, ) diff --git a/app/service/experimental_synapses_per_connection.py b/app/service/experimental_synapses_per_connection.py index dc5e56f6b..e5c30bdaf 100644 --- a/app/service/experimental_synapses_per_connection.py +++ b/app/service/experimental_synapses_per_connection.py @@ -140,7 +140,7 @@ def read_many( apply_data_query_operations=_load, pagination_request=pagination_request, response_schema_class=ExperimentalSynapsesPerConnectionRead, - authorized_project_id=user_context.project_id, + user_context=user_context, filter_joins=filter_joins, ) @@ -154,7 +154,7 @@ def read_one( db=db, id_=id_, db_model_class=ExperimentalSynapsesPerConnection, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=ExperimentalSynapsesPerConnectionRead, apply_operations=_load, ) @@ -168,7 +168,7 @@ def admin_read_one( db=db, id_=id_, db_model_class=ExperimentalSynapsesPerConnection, - authorized_project_id=None, + user_context=None, response_schema_class=ExperimentalSynapsesPerConnectionRead, apply_operations=_load, ) diff --git a/app/service/external_url.py b/app/service/external_url.py index 8636daca4..1cdc5cb2e 100644 --- a/app/service/external_url.py +++ b/app/service/external_url.py @@ -42,7 +42,7 @@ def read_one( db=db, id_=id_, db_model_class=ExternalUrl, - authorized_project_id=None, + user_context=None, response_schema_class=ExternalUrlRead, apply_operations=_load, ) @@ -53,7 +53,7 @@ def admin_read_one(db: SessionDep, id_: uuid.UUID) -> ExternalUrlRead: db=db, id_=id_, db_model_class=ExternalUrl, - authorized_project_id=None, + user_context=None, response_schema_class=ExternalUrlRead, apply_operations=_load, ) @@ -118,6 +118,6 @@ def read_many( aliases=aliases, pagination_request=pagination_request, response_schema_class=ExternalUrlRead, - authorized_project_id=None, + user_context=None, filter_joins=filter_joins, ) diff --git a/app/service/hierarchy.py b/app/service/hierarchy.py index 6d40a3b22..0f087da19 100644 --- a/app/service/hierarchy.py +++ b/app/service/hierarchy.py @@ -11,12 +11,13 @@ from app.dependencies.auth import UserContextDep from app.dependencies.db import SessionDep from app.logger import L +from app.schemas.auth import UserContext from app.schemas.hierarchy import HierarchyNode, HierarchyTree def _load_nodes( db: Session, - project_id: uuid.UUID | None, + user_context: UserContext, entity_class: type[Entity], derivation_type: DerivationType, ) -> dict[uuid.UUID, HierarchyNode]: @@ -40,7 +41,7 @@ def _load_nodes( .order_by(*order_by) ) query_roots = constrain_to_accessible_entities( - query_roots, project_id=project_id, db_model_class=root + query_roots, user_context=user_context, db_model_class=root ) query_children = ( sa.select( @@ -57,10 +58,10 @@ def _load_nodes( .order_by(*order_by) ) query_children = constrain_to_accessible_entities( - query_children, project_id=project_id, db_model_class=parent + query_children, user_context=user_context, db_model_class=parent ) query_children = constrain_to_accessible_entities( - query_children, project_id=project_id, db_model_class=child + query_children, user_context=user_context, db_model_class=child ) query = query_roots.union_all(query_children) @@ -105,7 +106,7 @@ def read_circuit_hierarchy( """ all_nodes = _load_nodes( db, - project_id=user_context.project_id, + user_context=user_context, entity_class=Circuit, derivation_type=derivation_type, ) diff --git a/app/service/ion_channel.py b/app/service/ion_channel.py index bf64ba1a1..1ecaf9f25 100644 --- a/app/service/ion_channel.py +++ b/app/service/ion_channel.py @@ -42,7 +42,7 @@ def read_one( db=db, id_=id_, db_model_class=IonChannel, - authorized_project_id=None, + user_context=None, response_schema_class=IonChannelRead, apply_operations=_load, ) @@ -108,6 +108,6 @@ def read_many( aliases=aliases, pagination_request=pagination_request, response_schema_class=IonChannelRead, - authorized_project_id=None, + user_context=None, filter_joins=filter_joins, ) diff --git a/app/service/ion_channel_model.py b/app/service/ion_channel_model.py index 035b77a73..45dd0ee22 100644 --- a/app/service/ion_channel_model.py +++ b/app/service/ion_channel_model.py @@ -89,7 +89,7 @@ def read_many( return router_read_many( db=db, db_model_class=IonChannelModel, - authorized_project_id=user_context.project_id, + user_context=user_context, with_search=with_search, with_in_brain_region=in_brain_region, facets=facets, @@ -113,7 +113,7 @@ def read_one( id_=id_, db=db, db_model_class=IonChannelModel, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=IonChannelModelExpanded, apply_operations=_load_expanded, ) @@ -127,7 +127,7 @@ def admin_read_one( id_=id_, db=db, db_model_class=IonChannelModel, - authorized_project_id=None, + user_context=None, response_schema_class=IonChannelModelExpanded, apply_operations=_load_expanded, ) diff --git a/app/service/ion_channel_recording.py b/app/service/ion_channel_recording.py index c0638dca2..3c862f510 100644 --- a/app/service/ion_channel_recording.py +++ b/app/service/ion_channel_recording.py @@ -67,7 +67,21 @@ def read_one( db=db, id_=id_, db_model_class=IonChannelRecording, - authorized_project_id=user_context.project_id, + user_context=user_context, + response_schema_class=IonChannelRecordingRead, + apply_operations=_load, + ) + + +def admin_read_one( + db: SessionDep, + id_: uuid.UUID, +) -> IonChannelRecordingRead: + return router_read_one( + db=db, + id_=id_, + db_model_class=IonChannelRecording, + user_context=None, response_schema_class=IonChannelRecordingRead, apply_operations=_load, ) @@ -148,7 +162,7 @@ def read_many( aliases=aliases, pagination_request=pagination_request, response_schema_class=IonChannelRecordingRead, - authorized_project_id=user_context.project_id, + user_context=user_context, filter_joins=filter_joins, ) diff --git a/app/service/license.py b/app/service/license.py index e1177097e..ad5bf12f5 100644 --- a/app/service/license.py +++ b/app/service/license.py @@ -16,7 +16,7 @@ def read_many( return router_read_many( db=db, db_model_class=License, - authorized_project_id=None, + user_context=None, with_search=None, with_in_brain_region=None, facets=None, @@ -35,7 +35,7 @@ def read_one(id_: uuid.UUID, db: SessionDep) -> LicenseRead: id_=id_, db=db, db_model_class=License, - authorized_project_id=None, + user_context=None, response_schema_class=LicenseRead, apply_operations=None, ) diff --git a/app/service/measurement_annotation.py b/app/service/measurement_annotation.py index ebbac1c54..8010f7412 100644 --- a/app/service/measurement_annotation.py +++ b/app/service/measurement_annotation.py @@ -54,7 +54,7 @@ def read_many( ) -> ListResponse[MeasurementAnnotationRead]: apply_filter_query_operations = lambda q: constrain_to_accessible_entities( q.join(Entity, Entity.id == MeasurementAnnotation.entity_id), - project_id=user_context.project_id, + user_context=user_context, ) facet_keys = [] filter_keys = [ @@ -70,7 +70,7 @@ def read_many( return router_read_many( db=db, db_model_class=MeasurementAnnotation, - authorized_project_id=None, # validated with apply_filter_query_operations + user_context=None, # validated with apply_filter_query_operations with_search=None, with_in_brain_region=in_brain_region, facets=None, @@ -92,14 +92,14 @@ def read_one( ) -> MeasurementAnnotationRead: def apply_operations(q): q = q.join(Entity, Entity.id == MeasurementAnnotation.entity_id) - q = constrain_to_accessible_entities(q, project_id=user_context.project_id) + q = constrain_to_accessible_entities(q, user_context=user_context) return _load_from_db(q=q) return router_read_one( id_=id_, db=db, db_model_class=MeasurementAnnotation, - authorized_project_id=None, # validated with apply_operations + user_context=None, # validated with apply_operations response_schema_class=MeasurementAnnotationRead, apply_operations=apply_operations, ) @@ -113,7 +113,7 @@ def admin_read_one( id_=id_, db=db, db_model_class=MeasurementAnnotation, - authorized_project_id=None, + user_context=None, response_schema_class=MeasurementAnnotationRead, apply_operations=_load_from_db, ) @@ -160,7 +160,7 @@ def apply_operations(q): id_=id_, db=db, db_model_class=MeasurementAnnotation, - authorized_project_id=None, # validated with apply_operations + user_context=None, # validated with apply_operations response_schema_class=MeasurementAnnotationRead, apply_operations=apply_operations, ) @@ -168,6 +168,6 @@ def apply_operations(q): id_=id_, db=db, db_model_class=MeasurementAnnotation, - authorized_project_id=None, # already validated + user_context=None, # already validated ) return one diff --git a/app/service/memodel.py b/app/service/memodel.py index fac7c2a8a..ad59c2678 100644 --- a/app/service/memodel.py +++ b/app/service/memodel.py @@ -87,7 +87,7 @@ def read_one(db: SessionDep, id_: uuid.UUID, user_context: UserContextDep) -> ME id_=id_, db=db, db_model_class=MEModel, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=MEModelRead, apply_operations=_load, ) @@ -98,7 +98,7 @@ def admin_read_one(db: SessionDep, id_: uuid.UUID) -> MEModelRead: id_=id_, db=db, db_model_class=MEModel, - authorized_project_id=None, + user_context=None, response_schema_class=MEModelRead, apply_operations=_load, ) @@ -182,7 +182,7 @@ def read_many( return router_read_many( db=db, db_model_class=MEModel, - authorized_project_id=user_context.project_id, + user_context=user_context, with_search=search, with_in_brain_region=in_brain_region, facets=facets, diff --git a/app/service/memodel_calibration_result.py b/app/service/memodel_calibration_result.py index 21db781a8..39af39c94 100644 --- a/app/service/memodel_calibration_result.py +++ b/app/service/memodel_calibration_result.py @@ -42,7 +42,7 @@ def read_one( db=db, id_=id_, db_model_class=MEModelCalibrationResult, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=MEModelCalibrationResultRead, apply_operations=_load, ) @@ -56,7 +56,7 @@ def admin_read_one( db=db, id_=id_, db_model_class=MEModelCalibrationResult, - authorized_project_id=None, + user_context=None, response_schema_class=MEModelCalibrationResultRead, apply_operations=_load, ) @@ -117,6 +117,6 @@ def read_many( aliases=aliases, pagination_request=pagination_request, response_schema_class=MEModelCalibrationResultRead, - authorized_project_id=user_context.project_id, + user_context=user_context, filter_joins=None, ) diff --git a/app/service/morphology.py b/app/service/morphology.py index c173a51f4..357c45223 100644 --- a/app/service/morphology.py +++ b/app/service/morphology.py @@ -89,7 +89,7 @@ def read_one( id_=id_, db=db, db_model_class=ReconstructionMorphology, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=response_schema_class, apply_operations=apply_operations, ) @@ -103,7 +103,7 @@ def admin_read_one( id_=id_, db=db, db_model_class=ReconstructionMorphology, - authorized_project_id=None, + user_context=None, response_schema_class=ReconstructionMorphologyRead, apply_operations=_load_from_db, ) @@ -185,7 +185,7 @@ def read_many( return router_read_many( db=db, db_model_class=ReconstructionMorphology, - authorized_project_id=user_context.project_id, + user_context=user_context, with_search=search, with_in_brain_region=in_brain_region, facets=with_facets, diff --git a/app/service/mtype.py b/app/service/mtype.py index 6d4c1a98f..3001985df 100644 --- a/app/service/mtype.py +++ b/app/service/mtype.py @@ -17,7 +17,7 @@ def read_many( return router_read_many( db=db, db_model_class=MTypeClass, - authorized_project_id=None, + user_context=None, with_search=None, with_in_brain_region=None, facets=None, @@ -36,7 +36,7 @@ def read_one(id_: uuid.UUID, db: SessionDep) -> MTypeClassRead: id_=id_, db=db, db_model_class=MTypeClass, - authorized_project_id=None, + user_context=None, response_schema_class=MTypeClassRead, apply_operations=None, ) diff --git a/app/service/mtype_classification.py b/app/service/mtype_classification.py index f8d945f00..aa6355a1e 100644 --- a/app/service/mtype_classification.py +++ b/app/service/mtype_classification.py @@ -5,7 +5,7 @@ from fastapi import HTTPException from sqlalchemy.orm import aliased, joinedload, raiseload -from app.db.auth import constrain_to_accessible_entities +from app.db.auth import constrain_entity_query_to_project from app.db.model import ( Agent, Entity, @@ -45,9 +45,10 @@ def create_one( json_model: MTypeClassificationCreate, user_context: UserContextWithProjectIdDep, ) -> MTypeClassificationRead: - stmt = constrain_to_accessible_entities( + # allow entities that are in the same project as the classification + stmt = constrain_entity_query_to_project( sa.select(sa.func.count(Entity.id)).where(Entity.id == json_model.entity_id), - user_context.project_id, + project_id=user_context.project_id, ) if db.execute(stmt).scalar_one() == 0: L.warning("Attempting to create an annotation for an entity inaccessible to user") @@ -79,7 +80,7 @@ def read_one( db=db, id_=id_, db_model_class=MTypeClassification, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=MTypeClassificationRead, apply_operations=_load, ) @@ -93,7 +94,7 @@ def admin_read_one( db=db, id_=id_, db_model_class=MTypeClassification, - authorized_project_id=None, + user_context=None, response_schema_class=MTypeClassificationRead, apply_operations=_load, ) @@ -141,7 +142,7 @@ def read_many( aliases=aliases, pagination_request=pagination_request, response_schema_class=MTypeClassificationRead, - authorized_project_id=user_context.project_id, + user_context=user_context, filter_joins=filter_joins, with_in_brain_region=None, ) diff --git a/app/service/organization.py b/app/service/organization.py index 6129d0832..e82454293 100644 --- a/app/service/organization.py +++ b/app/service/organization.py @@ -52,7 +52,7 @@ def read_many( return app.queries.common.router_read_many( db=db, db_model_class=Organization, - authorized_project_id=None, + user_context=None, with_search=None, with_in_brain_region=None, facets=None, @@ -72,21 +72,14 @@ def read_one(id_: uuid.UUID, db: SessionDep) -> OrganizationRead: id_=id_, db=db, db_model_class=Organization, - authorized_project_id=None, + user_context=None, response_schema_class=OrganizationRead, apply_operations=_load, ) -def admin_read_one(db: SessionDep, id_: uuid.UUID) -> OrganizationRead: - return app.queries.common.router_read_one( - id_=id_, - db=db, - db_model_class=Organization, - authorized_project_id=None, - response_schema_class=OrganizationRead, - apply_operations=_load, - ) +# global resource +admin_read_one = read_one def create_one( diff --git a/app/service/person.py b/app/service/person.py index 902bce10c..5df8bec69 100644 --- a/app/service/person.py +++ b/app/service/person.py @@ -57,7 +57,7 @@ def read_many( return router_read_many( db=db, db_model_class=Person, - authorized_project_id=None, + user_context=None, with_search=None, with_in_brain_region=None, facets=None, @@ -77,21 +77,14 @@ def read_one(id_: uuid.UUID, db: SessionDep) -> PersonRead: id_=id_, db=db, db_model_class=Person, - authorized_project_id=None, + user_context=None, response_schema_class=PersonRead, apply_operations=_load, ) -def admin_read_one(db: SessionDep, id_: uuid.UUID) -> PersonRead: - return router_read_one( - id_=id_, - db=db, - db_model_class=Person, - authorized_project_id=None, - response_schema_class=PersonRead, - apply_operations=_load, - ) +# global resource +admin_read_one = read_one def create_one(person: PersonCreate, db: SessionDep, user_context: AdminContextDep) -> PersonRead: diff --git a/app/service/publication.py b/app/service/publication.py index 2ef761ee0..ad51ff1c7 100644 --- a/app/service/publication.py +++ b/app/service/publication.py @@ -42,7 +42,7 @@ def read_one( db=db, id_=id_, db_model_class=Publication, - authorized_project_id=None, + user_context=None, response_schema_class=PublicationRead, apply_operations=_load, ) @@ -53,7 +53,7 @@ def admin_read_one(db: SessionDep, id_: uuid.UUID) -> PublicationRead: db=db, id_=id_, db_model_class=Publication, - authorized_project_id=None, + user_context=None, response_schema_class=PublicationRead, apply_operations=_load, ) @@ -118,6 +118,6 @@ def read_many( aliases=aliases, pagination_request=pagination_request, response_schema_class=PublicationRead, - authorized_project_id=None, + user_context=None, filter_joins=filter_joins, ) diff --git a/app/service/role.py b/app/service/role.py index 0cdbdc616..f04e34458 100644 --- a/app/service/role.py +++ b/app/service/role.py @@ -16,7 +16,7 @@ def read_many( return router_read_many( db=db, db_model_class=Role, - authorized_project_id=None, + user_context=None, with_search=None, with_in_brain_region=None, facets=None, @@ -35,7 +35,7 @@ def read_one(id_: uuid.UUID, db: SessionDep) -> RoleRead: id_=id_, db=db, db_model_class=Role, - authorized_project_id=None, + user_context=None, response_schema_class=RoleRead, apply_operations=None, ) diff --git a/app/service/scientific_artifact_external_url_link.py b/app/service/scientific_artifact_external_url_link.py index 0cc35744f..1a798dae7 100644 --- a/app/service/scientific_artifact_external_url_link.py +++ b/app/service/scientific_artifact_external_url_link.py @@ -77,7 +77,7 @@ def read_one( db=db, id_=id_, db_model_class=ScientificArtifactExternalUrlLink, - authorized_project_id=None, + user_context=None, response_schema_class=ScientificArtifactExternalUrlLinkRead, apply_operations=_load, ) @@ -93,7 +93,7 @@ def admin_read_one( db=db, id_=id_, db_model_class=ScientificArtifactExternalUrlLink, - authorized_project_id=None, + user_context=None, response_schema_class=ScientificArtifactExternalUrlLinkRead, apply_operations=_load, ) @@ -173,7 +173,7 @@ def read_many( updated_by_alias, ScientificArtifactExternalUrlLink.updated_by_id == updated_by_alias.id, ), - project_id=user_context.project_id, + user_context=user_context, db_model_class=scientific_artifact_alias, ) @@ -192,6 +192,6 @@ def read_many( aliases=aliases, pagination_request=pagination_request, response_schema_class=ScientificArtifactExternalUrlLinkRead, - authorized_project_id=user_context.project_id, + user_context=user_context, filter_joins=filter_joins, ) diff --git a/app/service/scientific_artifact_publication_link.py b/app/service/scientific_artifact_publication_link.py index e2f509f24..2909d20d5 100644 --- a/app/service/scientific_artifact_publication_link.py +++ b/app/service/scientific_artifact_publication_link.py @@ -75,7 +75,7 @@ def read_one( db=db, id_=id_, db_model_class=ScientificArtifactPublicationLink, - authorized_project_id=None, + user_context=None, response_schema_class=ScientificArtifactPublicationLinkRead, apply_operations=_load, ) @@ -91,7 +91,7 @@ def admin_read_one( db=db, id_=id_, db_model_class=ScientificArtifactPublicationLink, - authorized_project_id=None, + user_context=None, response_schema_class=ScientificArtifactPublicationLinkRead, apply_operations=_load, ) @@ -171,10 +171,9 @@ def read_many( updated_by_alias, ScientificArtifactPublicationLink.updated_by_id == updated_by_alias.id, ), - project_id=user_context.project_id, + user_context=user_context, db_model_class=scientific_artifact_alias, ) - load_with_aliases = lambda q: _load_with_eager(q, aliases) return router_read_many( @@ -190,6 +189,6 @@ def read_many( aliases=aliases, pagination_request=pagination_request, response_schema_class=ScientificArtifactPublicationLinkRead, - authorized_project_id=user_context.project_id, + user_context=user_context, filter_joins=filter_joins, ) diff --git a/app/service/simulation.py b/app/service/simulation.py index 6db63ec31..45da34b04 100644 --- a/app/service/simulation.py +++ b/app/service/simulation.py @@ -54,7 +54,7 @@ def read_one( db=db, id_=id_, db_model_class=Simulation, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=SimulationRead, apply_operations=_load, ) @@ -68,7 +68,7 @@ def admin_read_one( db=db, id_=id_, db_model_class=Simulation, - authorized_project_id=None, + user_context=None, response_schema_class=SimulationRead, apply_operations=_load, ) @@ -150,6 +150,6 @@ def read_many( aliases=aliases, pagination_request=pagination_request, response_schema_class=SimulationRead, - authorized_project_id=user_context.project_id, + user_context=user_context, filter_joins=filter_joins, ) diff --git a/app/service/simulation_campaign.py b/app/service/simulation_campaign.py index cf3e08466..40136d63b 100644 --- a/app/service/simulation_campaign.py +++ b/app/service/simulation_campaign.py @@ -57,7 +57,7 @@ def read_one( db=db, id_=id_, db_model_class=SimulationCampaign, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=SimulationCampaignRead, apply_operations=_load, ) @@ -71,7 +71,7 @@ def admin_read_one( db=db, id_=id_, db_model_class=SimulationCampaign, - authorized_project_id=None, + user_context=None, response_schema_class=SimulationCampaignRead, apply_operations=_load, ) @@ -158,6 +158,6 @@ def read_many( aliases=aliases, pagination_request=pagination_request, response_schema_class=SimulationCampaignRead, - authorized_project_id=user_context.project_id, + user_context=user_context, filter_joins=filter_joins, ) diff --git a/app/service/simulation_execution.py b/app/service/simulation_execution.py index 3ec985606..2d3e7ebf5 100644 --- a/app/service/simulation_execution.py +++ b/app/service/simulation_execution.py @@ -52,7 +52,7 @@ def read_one( db=db, id_=id_, db_model_class=SimulationExecution, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=SimulationExecutionRead, apply_operations=_load, ) @@ -66,7 +66,7 @@ def admin_read_one( db=db, id_=id_, db_model_class=SimulationExecution, - authorized_project_id=None, + user_context=None, response_schema_class=SimulationExecutionRead, apply_operations=_load, ) @@ -137,13 +137,13 @@ def read_many( aliases=aliases, pagination_request=pagination_request, response_schema_class=SimulationExecutionRead, - authorized_project_id=user_context.project_id, + user_context=user_context, filter_joins=filter_joins, ) def delete_one( - user_context: UserContextWithProjectIdDep, + user_context: UserContextDep, db: SessionDep, id_: uuid.UUID, ) -> SimulationExecutionRead: @@ -151,7 +151,7 @@ def delete_one( id_=id_, db=db, db_model_class=SimulationExecution, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=SimulationExecutionRead, apply_operations=_load, ) @@ -159,7 +159,7 @@ def delete_one( id_=id_, db=db, db_model_class=SimulationExecution, - authorized_project_id=None, # already validated + user_context=None, # already validated ) return one @@ -168,7 +168,7 @@ def update_one( db: SessionDep, id_: uuid.UUID, json_model: SimulationExecutionUpdate, # pyright: ignore [reportInvalidTypeForm] - user_context: UserContextWithProjectIdDep, + user_context: UserContextDep, ) -> SimulationExecutionRead: return router_update_activity_one( db=db, diff --git a/app/service/simulation_generation.py b/app/service/simulation_generation.py index 85b5f12a1..b5f125a9a 100644 --- a/app/service/simulation_generation.py +++ b/app/service/simulation_generation.py @@ -52,7 +52,7 @@ def read_one( db=db, id_=id_, db_model_class=SimulationGeneration, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=SimulationGenerationRead, apply_operations=_load, ) @@ -66,7 +66,7 @@ def admin_read_one( db=db, id_=id_, db_model_class=SimulationGeneration, - authorized_project_id=None, + user_context=None, response_schema_class=SimulationGenerationRead, apply_operations=_load, ) @@ -132,13 +132,13 @@ def read_many( aliases=aliases, pagination_request=pagination_request, response_schema_class=SimulationGenerationRead, - authorized_project_id=user_context.project_id, + user_context=user_context, filter_joins=filter_joins, ) def delete_one( - user_context: UserContextWithProjectIdDep, + user_context: UserContextDep, db: SessionDep, id_: uuid.UUID, ) -> SimulationGenerationRead: @@ -146,7 +146,7 @@ def delete_one( id_=id_, db=db, db_model_class=SimulationGeneration, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=SimulationGenerationRead, apply_operations=_load, ) @@ -154,7 +154,7 @@ def delete_one( id_=id_, db=db, db_model_class=SimulationGeneration, - authorized_project_id=None, # already validated + user_context=None, # already validated ) return one @@ -163,7 +163,7 @@ def update_one( db: SessionDep, id_: uuid.UUID, json_model: SimulationGenerationUpdate, # pyright: ignore [reportInvalidTypeForm] - user_context: UserContextWithProjectIdDep, + user_context: UserContextDep, ) -> SimulationGenerationRead: return router_update_activity_one( db=db, diff --git a/app/service/simulation_result.py b/app/service/simulation_result.py index 5f12e5b75..8f15af797 100644 --- a/app/service/simulation_result.py +++ b/app/service/simulation_result.py @@ -53,7 +53,7 @@ def read_one( db=db, id_=id_, db_model_class=SimulationResult, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=SimulationResultRead, apply_operations=_load, ) @@ -67,7 +67,7 @@ def admin_read_one( db=db, id_=id_, db_model_class=SimulationResult, - authorized_project_id=None, + user_context=None, response_schema_class=SimulationResultRead, apply_operations=_load, ) @@ -149,6 +149,6 @@ def read_many( aliases=aliases, pagination_request=pagination_request, response_schema_class=SimulationResultRead, - authorized_project_id=user_context.project_id, + user_context=user_context, filter_joins=filter_joins, ) diff --git a/app/service/single_neuron_simulation.py b/app/service/single_neuron_simulation.py index 007b2a4a8..3879b6e67 100644 --- a/app/service/single_neuron_simulation.py +++ b/app/service/single_neuron_simulation.py @@ -49,7 +49,7 @@ def read_one( db=db, id_=id_, db_model_class=SingleNeuronSimulation, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=SingleNeuronSimulationRead, apply_operations=_load, ) @@ -63,7 +63,7 @@ def admin_read_one( db=db, id_=id_, db_model_class=SingleNeuronSimulation, - authorized_project_id=None, + user_context=None, response_schema_class=SingleNeuronSimulationRead, apply_operations=_load, ) @@ -148,6 +148,6 @@ def read_many( aliases=aliases, pagination_request=pagination_request, response_schema_class=SingleNeuronSimulationRead, - authorized_project_id=user_context.project_id, + user_context=user_context, filter_joins=filter_joins, ) diff --git a/app/service/single_neuron_synaptome.py b/app/service/single_neuron_synaptome.py index 61e2ea654..45500dac8 100644 --- a/app/service/single_neuron_synaptome.py +++ b/app/service/single_neuron_synaptome.py @@ -51,7 +51,7 @@ def read_one( db=db, id_=id_, db_model_class=SingleNeuronSynaptome, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=SingleNeuronSynaptomeRead, apply_operations=_load, ) @@ -65,7 +65,7 @@ def admin_read_one( db=db, id_=id_, db_model_class=SingleNeuronSynaptome, - authorized_project_id=None, + user_context=None, response_schema_class=SingleNeuronSynaptomeRead, apply_operations=_load, ) @@ -150,6 +150,6 @@ def read_many( aliases=aliases, pagination_request=pagination_request, response_schema_class=SingleNeuronSynaptomeRead, - authorized_project_id=user_context.project_id, + user_context=user_context, filter_joins=filter_joins, ) diff --git a/app/service/single_neuron_synaptome_simulation.py b/app/service/single_neuron_synaptome_simulation.py index 2ff346663..3776cda4c 100644 --- a/app/service/single_neuron_synaptome_simulation.py +++ b/app/service/single_neuron_synaptome_simulation.py @@ -58,7 +58,7 @@ def read_one( return router_read_one( db=db, id_=id_, - authorized_project_id=user_context.project_id, + user_context=user_context, db_model_class=SingleNeuronSynaptomeSimulation, response_schema_class=SingleNeuronSynaptomeSimulationRead, apply_operations=_load, @@ -72,7 +72,7 @@ def admin_read_one( return router_read_one( db=db, id_=id_, - authorized_project_id=None, + user_context=None, db_model_class=SingleNeuronSynaptomeSimulation, response_schema_class=SingleNeuronSynaptomeSimulationRead, apply_operations=_load, @@ -158,6 +158,6 @@ def read_many( aliases=aliases, pagination_request=pagination_request, response_schema_class=SingleNeuronSynaptomeSimulationRead, - authorized_project_id=user_context.project_id, + user_context=user_context, filter_joins=filter_joins, ) diff --git a/app/service/species.py b/app/service/species.py index ad8806904..95f2ae1cc 100644 --- a/app/service/species.py +++ b/app/service/species.py @@ -32,7 +32,7 @@ def read_one( id_=id_, db=db, db_model_class=Species, - authorized_project_id=None, + user_context=None, response_schema_class=SpeciesRead, apply_operations=_load, ) @@ -82,7 +82,7 @@ def read_many( return app.queries.common.router_read_many( db=db, db_model_class=Species, - authorized_project_id=None, + user_context=None, with_search=None, with_in_brain_region=None, facets=None, diff --git a/app/service/strain.py b/app/service/strain.py index 9ec5efc75..81c1fa87e 100644 --- a/app/service/strain.py +++ b/app/service/strain.py @@ -48,7 +48,7 @@ def read_many( return app.queries.common.router_read_many( db=db, db_model_class=Strain, - authorized_project_id=None, + user_context=None, with_search=None, with_in_brain_region=None, facets=None, @@ -69,7 +69,7 @@ def read_one(id_: uuid.UUID, db: SessionDep) -> StrainRead: id_=id_, db=db, db_model_class=Strain, - authorized_project_id=None, + user_context=None, response_schema_class=StrainRead, apply_operations=_load, ) diff --git a/app/service/subject.py b/app/service/subject.py index 29f6b1b23..122c4dcc4 100644 --- a/app/service/subject.py +++ b/app/service/subject.py @@ -38,7 +38,7 @@ def read_one( db=db, id_=id_, db_model_class=Subject, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=SubjectRead, apply_operations=_load, ) @@ -52,7 +52,7 @@ def admin_read_one( db=db, id_=id_, db_model_class=Subject, - authorized_project_id=None, + user_context=None, response_schema_class=SubjectRead, apply_operations=_load, ) @@ -120,6 +120,6 @@ def read_many( aliases={}, pagination_request=pagination_request, response_schema_class=SubjectRead, - authorized_project_id=user_context.project_id, + user_context=user_context, filter_joins=filter_joins, ) diff --git a/app/service/validation.py b/app/service/validation.py index a33ace216..b40facf18 100644 --- a/app/service/validation.py +++ b/app/service/validation.py @@ -52,7 +52,7 @@ def read_one( db=db, id_=id_, db_model_class=Validation, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=ValidationRead, apply_operations=_load, ) @@ -66,7 +66,7 @@ def admin_read_one( db=db, id_=id_, db_model_class=Validation, - authorized_project_id=None, + user_context=None, response_schema_class=ValidationRead, apply_operations=_load, ) @@ -137,13 +137,13 @@ def read_many( aliases=aliases, pagination_request=pagination_request, response_schema_class=ValidationRead, - authorized_project_id=user_context.project_id, + user_context=user_context, filter_joins=filter_joins, ) def delete_one( - user_context: UserContextWithProjectIdDep, + user_context: UserContextDep, db: SessionDep, id_: uuid.UUID, ) -> ValidationRead: @@ -151,7 +151,7 @@ def delete_one( id_=id_, db=db, db_model_class=Validation, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=ValidationRead, apply_operations=_load, ) @@ -159,7 +159,7 @@ def delete_one( id_=id_, db=db, db_model_class=Validation, - authorized_project_id=None, # already validated + user_context=None, # already validated ) return one @@ -168,7 +168,7 @@ def update_one( db: SessionDep, id_: uuid.UUID, json_model: ValidationUpdate, # pyright: ignore [reportInvalidTypeForm] - user_context: UserContextWithProjectIdDep, + user_context: UserContextDep, ) -> ValidationRead: return router_update_activity_one( db=db, diff --git a/app/service/validation_result.py b/app/service/validation_result.py index ea67a15cf..39d8ea9f2 100644 --- a/app/service/validation_result.py +++ b/app/service/validation_result.py @@ -46,7 +46,7 @@ def read_one( db=db, id_=id_, db_model_class=ValidationResult, - authorized_project_id=user_context.project_id, + user_context=user_context, response_schema_class=ValidationResultRead, apply_operations=_load, ) @@ -60,7 +60,7 @@ def admin_read_one( db=db, id_=id_, db_model_class=ValidationResult, - authorized_project_id=None, + user_context=None, response_schema_class=ValidationResultRead, apply_operations=_load, ) @@ -122,6 +122,6 @@ def read_many( aliases=aliases, pagination_request=pagination_request, response_schema_class=ValidationResultRead, - authorized_project_id=user_context.project_id, + user_context=user_context, filter_joins=None, ) diff --git a/docs/authentication.md b/docs/authentication.md index 09c3cc0b1..30bd79c35 100644 --- a/docs/authentication.md +++ b/docs/authentication.md @@ -13,13 +13,20 @@ - **Write** endpoints require the user to be a member of the [service admin group](#service-admin-group). - Endpoints for [project resources](#endpoints-for-project-resources): - **Read** endpoints do not require `virtual-lab-id` or `project-id`, but: - - If neither is provided, only public resources are returned. - - If both are provided, both public and private resources are returned. + - If neither is provided, the project ids from the user token are used and all resources that match those ids plus public resources are returned. + - If both are provided, both public and private resources are returned, constrained within that project. - **Write** endpoints require both `virtual-lab-id` and `project-id`. + - **Update** endpoints do not require `virtual-lab-id` or `project-id`, but: + - If neither is provided, the update will be authorized if the `authorized_project_id` of the resource is in the `user_project_ids` extracted from the access token. + - If both are provided the update operation will only be authorized if in that specific project. + - **Delete** endpoints do not require `virtual-lab-id` or `project-id`, but: + - If neither is provided, the delete will be authorized if the `authorized_project_id` of the resource is in the `user_project_ids` extracted from the access token. + - If both are provided the delete operation will only be authorized if in that specific project. + ## Service Admin Group -To call endpoints that modify global resources, the user must belong to a special Keycloak group: `/service/entitycore/admin`. +To call endpoints that modify global resources, the user must belong to a special Keycloak group: `/service/entitycore/admin`. read/write/update/delete operations for the service admin group are not constrained by project ids. ## Caching @@ -34,6 +41,7 @@ Auditing is not yet implemented but can be added later using information retriev - `/brain-region` - `/cell-composition` +- `/consortium` - `/license` - `/mtype` - `/organization` diff --git a/tests/conftest.py b/tests/conftest.py index c2d14a426..59a5bcabd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ from collections.abc import Callable, Iterator from dataclasses import dataclass from datetime import timedelta +from typing import NamedTuple from uuid import UUID import boto3 @@ -55,11 +56,13 @@ AUTH_HEADER_ADMIN, AUTH_HEADER_USER_1, AUTH_HEADER_USER_2, + AUTH_HEADER_USER_1_IDS, PROJECT_HEADERS, PROJECT_ID, TOKEN_ADMIN, TOKEN_USER_1, TOKEN_USER_2, + TOKEN_USER_1_IDS, UNRELATED_PROJECT_HEADERS, UNRELATED_PROJECT_ID, UNRELATED_VIRTUAL_LAB_ID, @@ -147,6 +150,20 @@ def user_context_user_1(): ) +@pytest.fixture +def user_context_only_token_ids(): + return UserContext( + profile=UserProfile( + subject=UUID(USER_SUB_ID_1), + name="Regular User With Token", + ), + expiration=None, + is_authorized=True, + is_service_admin=False, + user_project_ids=[UUID(PROJECT_ID)], + ) + + @pytest.fixture def user_context_user_2(): """Regular authenticated user with different project-id.""" @@ -187,12 +204,14 @@ def _override_check_user_info( user_context_user_1, user_context_user_2, user_context_no_project, + user_context_only_token_ids ): # map (token, project-id) to the expected user_context mapping = { (TOKEN_ADMIN, None): user_context_admin, (TOKEN_ADMIN, UUID(PROJECT_ID)): user_context_admin_with_project, (TOKEN_USER_1, None): user_context_no_project, + (TOKEN_USER_1_IDS, None): user_context_only_token_ids, (TOKEN_USER_1, UUID(PROJECT_ID)): user_context_user_1, (TOKEN_USER_2, UUID(UNRELATED_PROJECT_ID)): user_context_user_2, } @@ -264,6 +283,28 @@ def client_no_project(client_no_auth): return ClientProxy(client_no_auth, headers=AUTH_HEADER_USER_1) +@pytest.fixture +def client_only_project_ids(client_no_auth): + """Return a web client instace, authenticated as a regular user with only user_project_ids.""" + return ClientProxy(client_no_auth, headers=AUTH_HEADER_USER_1_IDS) + + +@pytest.fixture +def clients(client_user_1, client_user_2, client_no_project, client_only_project_ids): + class Clients(NamedTuple): + user_1: ClientProxy + user_2: ClientProxy + no_project: ClientProxy + only_project_ids: ClientProxy + + return Clients( + user_1=client_user_1, + user_2=client_user_2, + no_project=client_no_project, + only_project_ids=client_only_project_ids, + ) + + @pytest.fixture def client(client_user_1): return client_user_1 diff --git a/tests/test_circuit.py b/tests/test_circuit.py index ab5c39f4d..eb35eb2de 100644 --- a/tests/test_circuit.py +++ b/tests/test_circuit.py @@ -14,7 +14,7 @@ delete_entity_contributions, ) -ROUTE = "circuit" +ROUTE = "/circuit" ADMIN_ROUTE = "/admin/circuit" @@ -161,12 +161,10 @@ def test_missing(client): check_missing(ROUTE, client) -def test_authorization(client_user_1, client_user_2, client_no_project, root_circuit_json_data): +def test_authorization(clients, root_circuit_json_data): # using root_circuit_json_data to avoid the implication of creating two circuits # because of the root_circuit_id in circuit_json_data which messes up the check assumptions - check_authorization( - ROUTE, client_user_1, client_user_2, client_no_project, root_circuit_json_data - ) + check_authorization(ROUTE, clients, root_circuit_json_data) def test_pagination(client, create_id): diff --git a/tests/test_electrical_cell_recording.py b/tests/test_electrical_cell_recording.py index 960a34aa1..6ff4f56a6 100644 --- a/tests/test_electrical_cell_recording.py +++ b/tests/test_electrical_cell_recording.py @@ -177,7 +177,7 @@ def test_admin_read_one( def _delete_stimuli(client_admin, trace_id): data = assert_request( client_admin.get, - url=f"{ROUTE}/{trace_id}", + url=f"{ADMIN_ROUTE}/{trace_id}", ).json() for stimulus in data["stimuli"]: @@ -219,10 +219,13 @@ def test_missing(client): def test_authorization( - client_user_1, client_user_2, client_no_project, electrical_cell_recording_json_data + clients, + electrical_cell_recording_json_data, ): check_authorization( - ROUTE, client_user_1, client_user_2, client_no_project, electrical_cell_recording_json_data + ROUTE, + clients, + electrical_cell_recording_json_data, ) diff --git a/tests/test_electrical_recording_stimulus.py b/tests/test_electrical_recording_stimulus.py index af8a55e49..ba1658ee6 100644 --- a/tests/test_electrical_recording_stimulus.py +++ b/tests/test_electrical_recording_stimulus.py @@ -143,13 +143,8 @@ def test_missing(client): check_missing(ROUTE, client) -def test_authorization( - client_user_1, - client_user_2, - client_no_project, - json_data, -): - check_authorization(ROUTE, client_user_1, client_user_2, client_no_project, json_data) +def test_authorization(clients, json_data): + check_authorization(ROUTE, clients, json_data) def test_pagination(client, create_id): diff --git a/tests/test_em_cell_mesh.py b/tests/test_em_cell_mesh.py index 55e01b9a1..4976a94e5 100644 --- a/tests/test_em_cell_mesh.py +++ b/tests/test_em_cell_mesh.py @@ -99,8 +99,8 @@ def test_missing(client): check_missing(ROUTE, client) -def test_authorization(client_user_1, client_user_2, client_no_project, json_data): - check_authorization(ROUTE, client_user_1, client_user_2, client_no_project, json_data) +def test_authorization(clients, json_data): + check_authorization(ROUTE, clients, json_data) def test_pagination(client, create_id): diff --git a/tests/test_experimental_bouton_density.py b/tests/test_experimental_bouton_density.py index 22d88c5bc..13b984154 100644 --- a/tests/test_experimental_bouton_density.py +++ b/tests/test_experimental_bouton_density.py @@ -155,13 +155,8 @@ def test_missing(client): check_missing(ROUTE, client) -def test_authorization( - client_user_1, - client_user_2, - client_no_project, - json_data, -): - check_authorization(ROUTE, client_user_1, client_user_2, client_no_project, json_data) +def test_authorization(clients, json_data): + check_authorization(ROUTE, clients, json_data) def test_pagination(client, create_id): diff --git a/tests/test_experimental_neuron_density.py b/tests/test_experimental_neuron_density.py index 727c46460..c427c3360 100644 --- a/tests/test_experimental_neuron_density.py +++ b/tests/test_experimental_neuron_density.py @@ -192,13 +192,8 @@ def test_missing(client): check_missing(ROUTE, client) -def test_authorization( - client_user_1, - client_user_2, - client_no_project, - json_data, -): - check_authorization(ROUTE, client_user_1, client_user_2, client_no_project, json_data) +def test_authorization(clients, json_data): + check_authorization(ROUTE, clients, json_data) def test_pagination(client, create_id): diff --git a/tests/test_experimental_synapses_per_connection.py b/tests/test_experimental_synapses_per_connection.py index f3d701433..54839d6f4 100644 --- a/tests/test_experimental_synapses_per_connection.py +++ b/tests/test_experimental_synapses_per_connection.py @@ -190,13 +190,8 @@ def test_missing(client): check_missing(ROUTE, client) -def test_authorization( - client_user_1, - client_user_2, - client_no_project, - json_data, -): - check_authorization(ROUTE, client_user_1, client_user_2, client_no_project, json_data) +def test_authorization(clients, json_data): + check_authorization(ROUTE, clients, json_data) def test_pagination(client, create_id): diff --git a/tests/test_ion_channel_model.py b/tests/test_ion_channel_model.py index 33394e8de..03f688763 100644 --- a/tests/test_ion_channel_model.py +++ b/tests/test_ion_channel_model.py @@ -226,9 +226,7 @@ def test_sorted(client: TestClient, subject_id: str, brain_region_id: uuid.UUID) def test_authorization( - client_user_1: TestClient, - client_user_2: TestClient, - client_no_project: TestClient, + clients, subject_id: str, brain_region_id: str, ): @@ -240,7 +238,7 @@ def test_authorization( "brain_region_id": brain_region_id, "subject_id": subject_id, } - check_authorization(ROUTE, client_user_1, client_user_2, client_no_project, json_data) + check_authorization(ROUTE, clients, json_data) def test_paginate(client: TestClient, subject_id: str, brain_region_id: uuid.UUID): diff --git a/tests/test_ion_channel_recording.py b/tests/test_ion_channel_recording.py index aa504cb7d..6ae0b5ec0 100644 --- a/tests/test_ion_channel_recording.py +++ b/tests/test_ion_channel_recording.py @@ -178,7 +178,7 @@ def test_read_one( def _delete_stimuli(client_admin, trace_id): data = assert_request( client_admin.get, - url=f"{ROUTE}/{trace_id}", + url=f"{ADMIN_ROUTE}/{trace_id}", ).json() for stimulus in data["stimuli"]: @@ -223,12 +223,8 @@ def test_missing(client): check_missing(ROUTE, client) -def test_authorization( - client_user_1, client_user_2, client_no_project, ion_channel_recording_json_data -): - check_authorization( - ROUTE, client_user_1, client_user_2, client_no_project, ion_channel_recording_json_data - ) +def test_authorization(clients, ion_channel_recording_json_data): + check_authorization(ROUTE, clients, ion_channel_recording_json_data) def test_pagination(client, ion_channel_recording_json_data): diff --git a/tests/test_morphology.py b/tests/test_morphology.py index 243800c7d..96f3a02f3 100644 --- a/tests/test_morphology.py +++ b/tests/test_morphology.py @@ -511,9 +511,7 @@ def test_query_reconstruction_morphology_species_join(db, client, brain_region_i def test_authorization( - client_user_1, - client_user_2, - client_no_project, + clients, species_id, strain_id, license_id, @@ -529,7 +527,7 @@ def test_authorization( "species_id": species_id, "strain_id": strain_id, } - check_authorization(ROUTE, client_user_1, client_user_2, client_no_project, json_data) + check_authorization(ROUTE, clients, json_data) def test_pagination(db, client, brain_region_id, person_id): diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 30798ece3..de9b48470 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -138,8 +138,8 @@ def test_missing(client): check_missing(ROUTE, client) -def test_authorization(client_user_1, client_user_2, client_no_project, json_data): - check_authorization(ROUTE, client_user_1, client_user_2, client_no_project, json_data) +def test_authorization(clients, json_data): + check_authorization(ROUTE, clients, json_data) def test_pagination(client, create_id): diff --git a/tests/test_simulation_campaign.py b/tests/test_simulation_campaign.py index d229bec7a..9a306ac51 100644 --- a/tests/test_simulation_campaign.py +++ b/tests/test_simulation_campaign.py @@ -90,8 +90,8 @@ def test_missing(client): check_missing(ROUTE, client) -def test_authorization(client_user_1, client_user_2, client_no_project, json_data): - check_authorization(ROUTE, client_user_1, client_user_2, client_no_project, json_data) +def test_authorization(clients, json_data): + check_authorization(ROUTE, clients, json_data) def test_pagination(client, create_id): diff --git a/tests/test_simulation_result.py b/tests/test_simulation_result.py index c4494c4ff..18c5905a7 100644 --- a/tests/test_simulation_result.py +++ b/tests/test_simulation_result.py @@ -130,10 +130,10 @@ def test_missing(client): check_missing(ROUTE, client) -def test_authorization(client_user_1, client_user_2, client_no_project, json_data): +def test_authorization(clients, json_data): # using root_circuit_json_data to avoid the implication of creating two circuits # because of the root_circuit_id in circuit_json_data which messes up the check assumptions - check_authorization(ROUTE, client_user_1, client_user_2, client_no_project, json_data) + check_authorization(ROUTE, clients, json_data) def test_pagination(client, create_id): diff --git a/tests/test_single_neuron_simulation.py b/tests/test_single_neuron_simulation.py index d5deded50..995305454 100644 --- a/tests/test_single_neuron_simulation.py +++ b/tests/test_single_neuron_simulation.py @@ -275,9 +275,7 @@ def test_missing(client, route_id, expected_status_code): ) -def test_authorization( - client_user_1, client_user_2, client_no_project, memodel_id, brain_region_id -): +def test_authorization(clients, memodel_id, brain_region_id): json_data = { "name": "foo", "description": "my-description", @@ -288,7 +286,7 @@ def test_authorization( "seed": 1, "brain_region_id": str(brain_region_id), } - check_authorization(ROUTE, client_user_1, client_user_2, client_no_project, json_data) + check_authorization(ROUTE, clients, json_data) def test_pagination(db, client, brain_region_id, emodel_id, morphology_id, species_id, person_id): diff --git a/tests/test_single_neuron_synaptome.py b/tests/test_single_neuron_synaptome.py index 9276fc305..823f624d0 100644 --- a/tests/test_single_neuron_synaptome.py +++ b/tests/test_single_neuron_synaptome.py @@ -211,8 +211,8 @@ def test_missing(client, route_id, expected_status_code): ) -def test_authorization(client_user_1, client_user_2, client_no_project, json_data): - check_authorization(ROUTE, client_user_1, client_user_2, client_no_project, json_data) +def test_authorization(clients, json_data): + check_authorization(ROUTE, clients, json_data) def test_pagination(db, client, brain_region_id, emodel_id, morphology_id, species_id, person_id): diff --git a/tests/test_single_neuron_synaptome_simulation.py b/tests/test_single_neuron_synaptome_simulation.py index a71631f1c..68fee4a71 100644 --- a/tests/test_single_neuron_synaptome_simulation.py +++ b/tests/test_single_neuron_synaptome_simulation.py @@ -279,8 +279,8 @@ def test_missing(client, route_id, expected_status_code): ) -def test_authorization(client_user_1, client_user_2, client_no_project, json_data): - check_authorization(ROUTE, client_user_1, client_user_2, client_no_project, json_data) +def test_authorization(clients, json_data): + check_authorization(ROUTE, clients, json_data) def test_pagination(db, client, brain_region_id, memodel_id, person_id): diff --git a/tests/test_subject.py b/tests/test_subject.py index 3f2f22dd0..9ff86bbf2 100644 --- a/tests/test_subject.py +++ b/tests/test_subject.py @@ -164,8 +164,8 @@ def test_missing(client): check_missing(ROUTE, client) -def test_authorization(client_user_1, client_user_2, client_no_project, json_data): - check_authorization(ROUTE, client_user_1, client_user_2, client_no_project, json_data) +def test_authorization(clients, json_data): + check_authorization(ROUTE, clients, json_data) def test_pagination(client, create_id): diff --git a/tests/utils.py b/tests/utils.py index e43466784..8224e1809 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -35,10 +35,12 @@ TOKEN_ADMIN = "I'm admin" # noqa: S105 TOKEN_USER_1 = "I'm user 1" # noqa: S105 TOKEN_USER_2 = "I'm user 2" # noqa: S105 +TOKEN_USER_1_IDS = "I'm user 1 with token only" # noqa: S105 AUTH_HEADER_ADMIN = {"Authorization": f"Bearer {TOKEN_ADMIN}"} AUTH_HEADER_USER_1 = {"Authorization": f"Bearer {TOKEN_USER_1}"} AUTH_HEADER_USER_2 = {"Authorization": f"Bearer {TOKEN_USER_2}"} +AUTH_HEADER_USER_1_IDS = {"Authorization": f"Bearer {TOKEN_USER_1_IDS}"} VIRTUAL_LAB_ID = "9c6fba01-2c6f-4eac-893f-f0dc665605c5" PROJECT_ID = "ee86d4a0-eaca-48ca-9788-ddc450250b15" @@ -394,7 +396,7 @@ def check_pagination(route, client, constructor_func): assert len(response_json["data"]) == 2 -def check_authorization(route, client_user_1, client_user_2, client_no_project, json_data): +def check_authorization(route, clients, json_data): """Check the authorization when trying to access the entities. Created entities: @@ -407,7 +409,7 @@ def check_authorization(route, client_user_1, client_user_2, client_no_project, """ # create the entities public_u1_0 = assert_request( - client_user_1.post, + clients.user_1.post, url=route, json=json_data | {"name": "Public u1/0", "authorized_public": True}, ).json() @@ -415,7 +417,7 @@ def check_authorization(route, client_user_1, client_user_2, client_no_project, assert public_u1_0["authorized_project_id"] == PROJECT_ID public_u2_0 = assert_request( - client_user_2.post, + clients.user_2.post, url=route, json=json_data | {"name": "Public u2/0", "authorized_public": True}, ).json() @@ -423,25 +425,25 @@ def check_authorization(route, client_user_1, client_user_2, client_no_project, assert public_u2_0["authorized_project_id"] == UNRELATED_PROJECT_ID private_u2_0 = assert_request( - client_user_2.post, url=route, json=json_data | {"name": "Private u2/0"} + clients.user_2.post, url=route, json=json_data | {"name": "Private u2/0"} ).json() assert private_u2_0["authorized_public"] is False assert private_u2_0["authorized_project_id"] == UNRELATED_PROJECT_ID private_u1_0 = assert_request( - client_user_1.post, url=route, json=json_data | {"name": "Private u1/0"} + clients.user_1.post, url=route, json=json_data | {"name": "Private u1/0"} ).json() assert private_u1_0["authorized_public"] is False assert private_u1_0["authorized_project_id"] == PROJECT_ID private_u1_1 = assert_request( - client_user_1.post, url=route, json=json_data | {"name": "Private u1/1"} + clients.user_1.post, url=route, json=json_data | {"name": "Private u1/1"} ).json() assert private_u1_1["authorized_public"] is False assert private_u1_1["authorized_project_id"] == PROJECT_ID # only return results that matches the desired project, and public ones - data = assert_request(client_user_1.get, url=route).json()["data"] + data = assert_request(clients.user_1.get, url=route).json()["data"] assert len(data) == 4 assert {row["id"] for row in data} == { public_u1_0["id"], @@ -452,13 +454,13 @@ def check_authorization(route, client_user_1, client_user_2, client_no_project, # cannot access the private entity of the other user assert_request( - client_user_1.get, + clients.user_1.get, url=f"{route}/{private_u2_0['id']}", expected_status_code=404, ) # client_no_project can get public entities only - data = assert_request(client_no_project.get, url=route).json()["data"] + data = assert_request(clients.no_project.get, url=route).json()["data"] assert len(data) == 2 assert {row["id"] for row in data} == { public_u1_0["id"], @@ -467,7 +469,7 @@ def check_authorization(route, client_user_1, client_user_2, client_no_project, # client_user_1 wants only public results data = assert_request( - client_user_1.get, + clients.user_1.get, url=route, params={"authorized_public": True}, ).json()["data"] @@ -476,7 +478,7 @@ def check_authorization(route, client_user_1, client_user_2, client_no_project, # client_user_1 wants only private (and accessible) results data = assert_request( - client_user_1.get, + clients.user_1.get, url=route, params={"authorized_public": False}, ).json()["data"] @@ -485,7 +487,7 @@ def check_authorization(route, client_user_1, client_user_2, client_no_project, # client_user_1 wants only their own entities (private or public) data = assert_request( - client_user_1.get, + clients.user_1.get, url=route, params={"authorized_project_id": PROJECT_ID}, ).json()["data"] @@ -498,7 +500,7 @@ def check_authorization(route, client_user_1, client_user_2, client_no_project, # client_user_1 can get entities in other projects only if they are public data = assert_request( - client_user_1.get, + clients.user_1.get, url=route, params={"authorized_project_id": UNRELATED_PROJECT_ID}, ).json()["data"] @@ -507,7 +509,7 @@ def check_authorization(route, client_user_1, client_user_2, client_no_project, # client_user_1 can get entities in other projects only if they are public (again) data = assert_request( - client_user_1.get, + clients.user_1.get, url=route, params={"authorized_public": True, "authorized_project_id": UNRELATED_PROJECT_ID}, ).json()["data"] @@ -516,7 +518,7 @@ def check_authorization(route, client_user_1, client_user_2, client_no_project, # client_user_1 can get entities in other projects only if they are public (no results) data = assert_request( - client_user_1.get, + clients.user_1.get, url=route, params={"authorized_public": False, "authorized_project_id": UNRELATED_PROJECT_ID}, ).json()["data"] @@ -524,7 +526,7 @@ def check_authorization(route, client_user_1, client_user_2, client_no_project, # client_user_2 wants only their own entities (private or public) data = assert_request( - client_user_2.get, + clients.user_2.get, url=route, params={"authorized_project_id": UNRELATED_PROJECT_ID}, ).json()["data"] @@ -533,7 +535,7 @@ def check_authorization(route, client_user_1, client_user_2, client_no_project, # client_user_1 wants only their own public entities data = assert_request( - client_user_1.get, + clients.user_1.get, url=route, params={"authorized_public": True, "authorized_project_id": PROJECT_ID}, ).json()["data"] @@ -542,13 +544,46 @@ def check_authorization(route, client_user_1, client_user_2, client_no_project, # client_user_1 wants only their own private entities data = assert_request( - client_user_1.get, + clients.user_1.get, url=route, params={"authorized_public": False, "authorized_project_id": PROJECT_ID}, ).json()["data"] assert len(data) == 2 assert {row["id"] for row in data} == {private_u1_0["id"], private_u1_1["id"]} + # if there is no context the user, but a token with user_project_ids + # should still be able to fetch the resource + + data = assert_request( + clients.only_project_ids.get, + url=f"{route}/{public_u1_0['id']}", + ).json() + assert data["id"] == str(public_u1_0['id']) + + data = assert_request( + clients.only_project_ids.get, + url=f"{route}/{public_u2_0['id']}", + ).json() + assert data["id"] == str(public_u2_0['id']) + + data = assert_request( + clients.only_project_ids.get, + url=f"{route}/{private_u1_0['id']}", + ).json() + assert data["id"] == str(private_u1_0['id']) + + data = assert_request( + clients.only_project_ids.get, + url=f"{route}/{private_u2_0['id']}", + expected_status_code=404, + ).json() + + data = assert_request( + clients.only_project_ids.get, + url=route, + ).json()["data"] + assert len(data) == 4 + def check_brain_region_filter(route, client, db, brain_region_hierarchy_id, create_model_function): db_hierarchy = db.get(BrainRegionHierarchy, brain_region_hierarchy_id) @@ -704,7 +739,7 @@ def count_db_class(db, db_class): def delete_entity_contributions(client_admin, entity_route, entity_id): data = assert_request( client_admin.get, - url=f"{entity_route}/{entity_id}", + url=f"/admin{entity_route}/{entity_id}", ).json() for contribution in data["contributions"]: @@ -718,7 +753,7 @@ def delete_entity_contributions(client_admin, entity_route, entity_id): def delete_entity_assets(client_admin, entity_route, entity_id): data = assert_request( client_admin.get, - url=f"{entity_route}/{entity_id}", + url=f"/admin{entity_route}/{entity_id}", ).json() for json_asset in data["assets"]: