Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support signing with CAI Server in sdk #240

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions src/stability_sdk/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def generate(
guidance_strength: float = 0.0,
preset: Optional[str] = None,
return_request: bool = False,
content_credentials_add_default: bool = False,
) -> Dict[int, List[Any]]:
"""
Generate an image from a set of weighted prompts.
Expand Down Expand Up @@ -164,7 +165,7 @@ def generate(
start_schedule = 1.0 - init_strength
image_params = self._build_image_params(width, height, sampler, steps, seed, samples, cfg_scale,
start_schedule, init_noise_scale, masked_area_init,
guidance_preset, guidance_cuts, guidance_strength)
guidance_preset, guidance_cuts, guidance_strength, content_credentials_add_default)

extras = Struct()
if preset and preset.lower() != 'none':
Expand Down Expand Up @@ -235,7 +236,7 @@ def inpaint(
start_schedule = 1.0-init_strength
image_params = self._build_image_params(width, height, sampler, steps, seed, samples, cfg_scale,
start_schedule, init_noise_scale, masked_area_init,
guidance_preset, guidance_cuts, guidance_strength)
guidance_preset, guidance_cuts, guidance_strength, ccontent_credentials_add_default=False)

extras = Struct()
if preset and preset.lower() != 'none':
Expand Down Expand Up @@ -538,7 +539,7 @@ def _adjust_request_for_retry(self, request: generation.Request, attempt: int):

def _build_image_params(self, width, height, sampler, steps, seed, samples, cfg_scale,
schedule_start, init_noise_scale, masked_area_init,
guidance_preset, guidance_cuts, guidance_strength):
guidance_preset, guidance_cuts, guidance_strength, ccontent_credentials_add_default):

if not seed:
seed = [random.randrange(0, 4294967295)]
Expand Down Expand Up @@ -568,6 +569,12 @@ def _build_image_params(self, width, height, sampler, steps, seed, samples, cfg_
)
]
)
# empty Content Credentials Parameters will result in images not being signed by the Content Credentials server
content_credentials_params = generation.ContentCredentialsParameters()
if ccontent_credentials_add_default:
content_credentials_params = generation.ContentCredentialsParameters(
model_metadata=generation._CONTENTCREDENTIALSPARAMETERS_MODELMETADATA.values_by_name[
'MODEL_METADATA_SIGN_WITH_ENGINE_ID'].number)

return generation.ImageParameters(
transform=None if sampler is None else generation.TransformType(diffusion=sampler),
Expand All @@ -578,6 +585,7 @@ def _build_image_params(self, width, height, sampler, steps, seed, samples, cfg_
samples=samples,
masked_area_init=masked_area_init,
parameters=[generation.StepParameter(**step_parameters)],
content_credentials_parameters=content_credentials_params
)

def _process_response(self, response) -> Dict[int, List[Any]]:
Expand Down
16 changes: 15 additions & 1 deletion src/stability_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def generate(
guidance_strength: Optional[float] = None,
guidance_prompt: Union[str, generation.Prompt] = None,
guidance_models: List[str] = None,
content_credentials_add_default: bool = False,
) -> Generator[generation.Answer, None, None]:
"""
Generate images from a prompt.
Expand All @@ -194,6 +195,7 @@ def generate(
:param guidance_strength: Strength of the guidance. We recommend values in range [0.0,1.0]. A good default is 0.25
:param guidance_prompt: Prompt to use for guidance, defaults to `prompt` argument (above) if not specified.
:param guidance_models: Models to use for guidance.
:param content_credentials_add_default: Add default Content Credentials or not.
:return: Generator of Answer objects.
"""
if (prompt is None) and (init_image is None):
Expand Down Expand Up @@ -277,6 +279,13 @@ def generate(
transform=None
if sampler:
transform=generation.TransformType(diffusion=sampler)

# empty Content Credential Parameters will result in images not being signed by the Content Credential server
content_credentials_params = generation.ContentCredentialsParameters()
if content_credentials_add_default:
content_credentials_params = generation.ContentCredentialsParameters(
model_metadata=generation._CONTENTCREDENTIALSPARAMETERS_MODELMETADATA.values_by_name[
'MODEL_METADATA_SIGN_WITH_ENGINE_ID'].number)

image_parameters=generation.ImageParameters(
transform=transform,
Expand All @@ -286,6 +295,7 @@ def generate(
steps=steps,
samples=samples,
parameters=[generation.StepParameter(**step_parameters)],
content_credentials_parameters=content_credentials_params,
)

return self.emit_request(prompt=prompts, image_parameters=image_parameters)
Expand Down Expand Up @@ -498,6 +508,9 @@ def process_cli(logger: logging.Logger = None,
parser_generate.add_argument(
"--width", "-W", type=int, default=512, help="[512] width of image"
)
parser_generate.add_argument(
"--content_credentials_add_default", type=bool, action='store_true', default=False, help="Attatch a signed manifest to artifacts using C2PA. The default manifest will contain engine id and publisher name (Stability AI)"
)
parser_generate.add_argument(
"--start_schedule",
type=float,
Expand Down Expand Up @@ -626,11 +639,12 @@ def process_cli(logger: logging.Logger = None,
"width": args.width,
"start_schedule": args.start_schedule,
"end_schedule": args.end_schedule,
"cfg_scale": args.cfg_scale,
"cfg_scale": args.cfg_scale,
"seed": args.seed,
"samples": args.num_samples,
"init_image": args.init_image,
"mask_image": args.mask_image,
"content_credentials_add_default": args.content_credentials_add_default,
}

if args.sampler:
Expand Down
2 changes: 1 addition & 1 deletion src/stability_sdk/interfaces
23 changes: 22 additions & 1 deletion tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def ChainGenerate(self, chain: generation.ChainRequest, **kwargs) -> Generator[g
for answer in self.Generate(stage.request):
artifacts.extend(answer.artifacts)
for artifact in artifacts:
yield generation.Answer(artifacts=[artifact])
yield generation.Answer(artifacts=[artifact])

def Generate(self, request: generation.Request, **kwargs) -> Generator[generation.Answer, None, None]:
if request.HasField("image"):
Expand Down Expand Up @@ -99,6 +99,27 @@ def test_api_generate():
assert isinstance(image, Image.Image)
assert image.size == (width, height)

def test_api_generate_content_credentials_signing_set():
class ContentCredentialsMockStub(MockStub):
def Generate(self, request: generation.Request, **kwargs) -> Generator[generation.Answer, None, None]:
assert request.image.content_credentials_parameters.model_metadata == \
generation._CONTENTCREDENTIALSPARAMETERS_MODELMETADATA.values_by_name['MODEL_METADATA_SIGN_WITH_ENGINE_ID'].number
return super().Generate(request, **kwargs)
api = Context(stub=ContentCredentialsMockStub())
width, height = 512, 768
results = api.generate(prompts=["foo bar"], weights=[1.0], width=width, height=height, content_credentials_add_default=True)

def test_api_generate_content_credentials_signing_unset():
class ContentCredentialsMockStub(MockStub):
def Generate(self, request: generation.Request, **kwargs) -> Generator[generation.Answer, None, None]:
assert request.image.content_credentials_parameters.model_metadata == \
generation._CONTENTCREDENTIALSPARAMETERS_MODELMETADATA.values_by_name['MODEL_METADATA_UNSPECIFIED'].number
return super().Generate(request, **kwargs)
api = Context(stub=ContentCredentialsMockStub())
width, height = 512, 768
# content_credentials_add_default should default to false.
results = api.generate(prompts=["foo bar"], weights=[1.0], width=width, height=height)

def test_api_inpaint():
api = Context(stub=MockStub())
width, height = 512, 768
Expand Down
33 changes: 32 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,40 @@
from PIL import Image
from typing import Generator
from google.protobuf.struct_pb2 import Struct
from typing import Generator, Optional

from stability_sdk import client
from stability_sdk.api import generation

class TestStabilityInferenceImageParameters(client.StabilityInference):
def emit_request(
self,
prompt: generation.Prompt,
image_parameters: generation.ImageParameters,
extra_parameters: Optional[Struct] = None,
engine_id: str = None,
request_id: str = None,
):
return image_parameters


def test_content_credentials_not_set():
class_instance = TestStabilityInferenceImageParameters(
host='foo.bar.baz', key='thisIsNotARealKey')
image_params = class_instance.generate(prompt="foo bar")

assert image_params.content_credentials_parameters.model_metadata == \
generation._CONTENTCREDENTIALSPARAMETERS_MODELMETADATA.values_by_name[
'MODEL_METADATA_UNSPECIFIED'].number

def test_content_credentials_set():
class_instance = TestStabilityInferenceImageParameters(
host='foo.bar.baz', key='thisIsNotARealKey')
image_params = class_instance.generate(prompt="foo bar",
content_credentials_add_default=True)

assert image_params.content_credentials_parameters.model_metadata == \
generation._CONTENTCREDENTIALSPARAMETERS_MODELMETADATA.values_by_name[
'MODEL_METADATA_SIGN_WITH_ENGINE_ID'].number

def test_StabilityInference_init():
_ = client.StabilityInference(key='thisIsNotARealKey')
Expand Down