diff --git a/src/stability_sdk/api.py b/src/stability_sdk/api.py index 2057961f..44e4c4d6 100644 --- a/src/stability_sdk/api.py +++ b/src/stability_sdk/api.py @@ -122,6 +122,7 @@ def generate( guidance_strength: float = 0.0, preset: Optional[str] = None, return_request: bool = False, + content_credentials_add_default: bool = False, ) -> Dict[int, List[Any]]: """ Generate an image from a set of weighted prompts. @@ -164,7 +165,7 @@ def generate( start_schedule = 1.0 - init_strength image_params = self._build_image_params(width, height, sampler, steps, seed, samples, cfg_scale, start_schedule, init_noise_scale, masked_area_init, - guidance_preset, guidance_cuts, guidance_strength) + guidance_preset, guidance_cuts, guidance_strength, content_credentials_add_default) extras = Struct() if preset and preset.lower() != 'none': @@ -235,7 +236,7 @@ def inpaint( start_schedule = 1.0-init_strength image_params = self._build_image_params(width, height, sampler, steps, seed, samples, cfg_scale, start_schedule, init_noise_scale, masked_area_init, - guidance_preset, guidance_cuts, guidance_strength) + guidance_preset, guidance_cuts, guidance_strength, ccontent_credentials_add_default=False) extras = Struct() if preset and preset.lower() != 'none': @@ -538,7 +539,7 @@ def _adjust_request_for_retry(self, request: generation.Request, attempt: int): def _build_image_params(self, width, height, sampler, steps, seed, samples, cfg_scale, schedule_start, init_noise_scale, masked_area_init, - guidance_preset, guidance_cuts, guidance_strength): + guidance_preset, guidance_cuts, guidance_strength, ccontent_credentials_add_default): if not seed: seed = [random.randrange(0, 4294967295)] @@ -568,6 +569,12 @@ def _build_image_params(self, width, height, sampler, steps, seed, samples, cfg_ ) ] ) + # empty Content Credentials Parameters will result in images not being signed by the Content Credentials server + content_credentials_params = generation.ContentCredentialsParameters() + if ccontent_credentials_add_default: + content_credentials_params = generation.ContentCredentialsParameters( + model_metadata=generation._CONTENTCREDENTIALSPARAMETERS_MODELMETADATA.values_by_name[ + 'MODEL_METADATA_SIGN_WITH_ENGINE_ID'].number) return generation.ImageParameters( transform=None if sampler is None else generation.TransformType(diffusion=sampler), @@ -578,6 +585,7 @@ def _build_image_params(self, width, height, sampler, steps, seed, samples, cfg_ samples=samples, masked_area_init=masked_area_init, parameters=[generation.StepParameter(**step_parameters)], + content_credentials_parameters=content_credentials_params ) def _process_response(self, response) -> Dict[int, List[Any]]: diff --git a/src/stability_sdk/client.py b/src/stability_sdk/client.py index 20547f00..ea958085 100644 --- a/src/stability_sdk/client.py +++ b/src/stability_sdk/client.py @@ -171,6 +171,7 @@ def generate( guidance_strength: Optional[float] = None, guidance_prompt: Union[str, generation.Prompt] = None, guidance_models: List[str] = None, + content_credentials_add_default: bool = False, ) -> Generator[generation.Answer, None, None]: """ Generate images from a prompt. @@ -194,6 +195,7 @@ def generate( :param guidance_strength: Strength of the guidance. We recommend values in range [0.0,1.0]. A good default is 0.25 :param guidance_prompt: Prompt to use for guidance, defaults to `prompt` argument (above) if not specified. :param guidance_models: Models to use for guidance. + :param content_credentials_add_default: Add default Content Credentials or not. :return: Generator of Answer objects. """ if (prompt is None) and (init_image is None): @@ -277,6 +279,13 @@ def generate( transform=None if sampler: transform=generation.TransformType(diffusion=sampler) + + # empty Content Credential Parameters will result in images not being signed by the Content Credential server + content_credentials_params = generation.ContentCredentialsParameters() + if content_credentials_add_default: + content_credentials_params = generation.ContentCredentialsParameters( + model_metadata=generation._CONTENTCREDENTIALSPARAMETERS_MODELMETADATA.values_by_name[ + 'MODEL_METADATA_SIGN_WITH_ENGINE_ID'].number) image_parameters=generation.ImageParameters( transform=transform, @@ -286,6 +295,7 @@ def generate( steps=steps, samples=samples, parameters=[generation.StepParameter(**step_parameters)], + content_credentials_parameters=content_credentials_params, ) return self.emit_request(prompt=prompts, image_parameters=image_parameters) @@ -498,6 +508,9 @@ def process_cli(logger: logging.Logger = None, parser_generate.add_argument( "--width", "-W", type=int, default=512, help="[512] width of image" ) + parser_generate.add_argument( + "--content_credentials_add_default", type=bool, action='store_true', default=False, help="Attatch a signed manifest to artifacts using C2PA. The default manifest will contain engine id and publisher name (Stability AI)" + ) parser_generate.add_argument( "--start_schedule", type=float, @@ -626,11 +639,12 @@ def process_cli(logger: logging.Logger = None, "width": args.width, "start_schedule": args.start_schedule, "end_schedule": args.end_schedule, - "cfg_scale": args.cfg_scale, + "cfg_scale": args.cfg_scale, "seed": args.seed, "samples": args.num_samples, "init_image": args.init_image, "mask_image": args.mask_image, + "content_credentials_add_default": args.content_credentials_add_default, } if args.sampler: diff --git a/src/stability_sdk/interfaces b/src/stability_sdk/interfaces index f3a50851..bc9b859d 160000 --- a/src/stability_sdk/interfaces +++ b/src/stability_sdk/interfaces @@ -1 +1 @@ -Subproject commit f3a50851f8ea158fef1b1d76661cfd9a8cf83e01 +Subproject commit bc9b859d0e6446099ea79bd2d0d2ade2decb9385 diff --git a/tests/test_api.py b/tests/test_api.py index b8a6c740..4fbd59ab 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -32,7 +32,7 @@ def ChainGenerate(self, chain: generation.ChainRequest, **kwargs) -> Generator[g for answer in self.Generate(stage.request): artifacts.extend(answer.artifacts) for artifact in artifacts: - yield generation.Answer(artifacts=[artifact]) + yield generation.Answer(artifacts=[artifact]) def Generate(self, request: generation.Request, **kwargs) -> Generator[generation.Answer, None, None]: if request.HasField("image"): @@ -99,6 +99,27 @@ def test_api_generate(): assert isinstance(image, Image.Image) assert image.size == (width, height) +def test_api_generate_content_credentials_signing_set(): + class ContentCredentialsMockStub(MockStub): + def Generate(self, request: generation.Request, **kwargs) -> Generator[generation.Answer, None, None]: + assert request.image.content_credentials_parameters.model_metadata == \ + generation._CONTENTCREDENTIALSPARAMETERS_MODELMETADATA.values_by_name['MODEL_METADATA_SIGN_WITH_ENGINE_ID'].number + return super().Generate(request, **kwargs) + api = Context(stub=ContentCredentialsMockStub()) + width, height = 512, 768 + results = api.generate(prompts=["foo bar"], weights=[1.0], width=width, height=height, content_credentials_add_default=True) + +def test_api_generate_content_credentials_signing_unset(): + class ContentCredentialsMockStub(MockStub): + def Generate(self, request: generation.Request, **kwargs) -> Generator[generation.Answer, None, None]: + assert request.image.content_credentials_parameters.model_metadata == \ + generation._CONTENTCREDENTIALSPARAMETERS_MODELMETADATA.values_by_name['MODEL_METADATA_UNSPECIFIED'].number + return super().Generate(request, **kwargs) + api = Context(stub=ContentCredentialsMockStub()) + width, height = 512, 768 + # content_credentials_add_default should default to false. + results = api.generate(prompts=["foo bar"], weights=[1.0], width=width, height=height) + def test_api_inpaint(): api = Context(stub=MockStub()) width, height = 512, 768 diff --git a/tests/test_client.py b/tests/test_client.py index 69b93ca8..bff5ff14 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,9 +1,40 @@ from PIL import Image -from typing import Generator +from google.protobuf.struct_pb2 import Struct +from typing import Generator, Optional from stability_sdk import client from stability_sdk.api import generation +class TestStabilityInferenceImageParameters(client.StabilityInference): + def emit_request( + self, + prompt: generation.Prompt, + image_parameters: generation.ImageParameters, + extra_parameters: Optional[Struct] = None, + engine_id: str = None, + request_id: str = None, + ): + return image_parameters + + +def test_content_credentials_not_set(): + class_instance = TestStabilityInferenceImageParameters( + host='foo.bar.baz', key='thisIsNotARealKey') + image_params = class_instance.generate(prompt="foo bar") + + assert image_params.content_credentials_parameters.model_metadata == \ + generation._CONTENTCREDENTIALSPARAMETERS_MODELMETADATA.values_by_name[ + 'MODEL_METADATA_UNSPECIFIED'].number + +def test_content_credentials_set(): + class_instance = TestStabilityInferenceImageParameters( + host='foo.bar.baz', key='thisIsNotARealKey') + image_params = class_instance.generate(prompt="foo bar", + content_credentials_add_default=True) + + assert image_params.content_credentials_parameters.model_metadata == \ + generation._CONTENTCREDENTIALSPARAMETERS_MODELMETADATA.values_by_name[ + 'MODEL_METADATA_SIGN_WITH_ENGINE_ID'].number def test_StabilityInference_init(): _ = client.StabilityInference(key='thisIsNotARealKey')