Skip to content

Commit

Permalink
Support signing with CAI Server in sdk
Browse files Browse the repository at this point in the history
  • Loading branch information
Richard Kennedy committed Jun 12, 2023
1 parent a72b5f1 commit 1dd9d49
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 5 deletions.
14 changes: 11 additions & 3 deletions src/stability_sdk/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def generate(
guidance_strength: float = 0.0,
preset: Optional[str] = None,
return_request: bool = False,
sign_with_cai: bool = False,
) -> Dict[int, List[Any]]:
"""
Generate an image from a set of weighted prompts.
Expand Down Expand Up @@ -164,7 +165,7 @@ def generate(
start_schedule = 1.0 - init_strength
image_params = self._build_image_params(width, height, sampler, steps, seed, samples, cfg_scale,
start_schedule, init_noise_scale, masked_area_init,
guidance_preset, guidance_cuts, guidance_strength)
guidance_preset, guidance_cuts, guidance_strength, sign_with_cai)

extras = Struct()
if preset and preset.lower() != 'none':
Expand Down Expand Up @@ -235,7 +236,7 @@ def inpaint(
start_schedule = 1.0-init_strength
image_params = self._build_image_params(width, height, sampler, steps, seed, samples, cfg_scale,
start_schedule, init_noise_scale, masked_area_init,
guidance_preset, guidance_cuts, guidance_strength)
guidance_preset, guidance_cuts, guidance_strength, sign_with_cai=False)

extras = Struct()
if preset and preset.lower() != 'none':
Expand Down Expand Up @@ -538,7 +539,7 @@ def _adjust_request_for_retry(self, request: generation.Request, attempt: int):

def _build_image_params(self, width, height, sampler, steps, seed, samples, cfg_scale,
schedule_start, init_noise_scale, masked_area_init,
guidance_preset, guidance_cuts, guidance_strength):
guidance_preset, guidance_cuts, guidance_strength, sign_with_cai):

if not seed:
seed = [random.randrange(0, 4294967295)]
Expand Down Expand Up @@ -568,6 +569,12 @@ def _build_image_params(self, width, height, sampler, steps, seed, samples, cfg_
)
]
)
# empty CAI Parameters will result in images not being signed by the CAI server
caip = generation.CAIParameters()
if sign_with_cai:
caip = generation.CAIParameters(
model_metadata=generation._CAIPARAMETERS_MODELMETADATA.values_by_name[
'SIGN_WITH_ENGINE_ID'].number)

return generation.ImageParameters(
transform=None if sampler is None else generation.TransformType(diffusion=sampler),
Expand All @@ -578,6 +585,7 @@ def _build_image_params(self, width, height, sampler, steps, seed, samples, cfg_
samples=samples,
masked_area_init=masked_area_init,
parameters=[generation.StepParameter(**step_parameters)],
cai_parameters=caip
)

def _process_response(self, response) -> Dict[int, List[Any]]:
Expand Down
6 changes: 5 additions & 1 deletion src/stability_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,9 @@ def process_cli(logger: logging.Logger = None,
parser_generate.add_argument(
"--width", "-W", type=int, default=512, help="[512] width of image"
)
parser_generate.add_argument(
"--sign_with_cai", type=bool, default=False, help="Sign artifacts using C2PA to include providence data containing engine id"
)
parser_generate.add_argument(
"--start_schedule",
type=float,
Expand Down Expand Up @@ -626,11 +629,12 @@ def process_cli(logger: logging.Logger = None,
"width": args.width,
"start_schedule": args.start_schedule,
"end_schedule": args.end_schedule,
"cfg_scale": args.cfg_scale,
"cfg_scale": args.cfg_scale,
"seed": args.seed,
"samples": args.num_samples,
"init_image": args.init_image,
"mask_image": args.mask_image,
"sign_with_cai": args.sign_with_cai,
}

if args.sampler:
Expand Down
24 changes: 23 additions & 1 deletion tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def ChainGenerate(self, chain: generation.ChainRequest, **kwargs) -> Generator[g
for answer in self.Generate(stage.request):
artifacts.extend(answer.artifacts)
for artifact in artifacts:
yield generation.Answer(artifacts=[artifact])
yield generation.Answer(artifacts=[artifact])

def Generate(self, request: generation.Request, **kwargs) -> Generator[generation.Answer, None, None]:
if request.HasField("image"):
Expand Down Expand Up @@ -99,6 +99,28 @@ def test_api_generate():
assert isinstance(image, Image.Image)
assert image.size == (width, height)

def test_api_generate_cai_signing_set():
class CAIMockStub(MockStub):
def Generate(self, request: generation.Request, **kwargs) -> Generator[generation.Answer, None, None]:
assert request.image.cai_parameters.model_metadata == \
generation._CAIPARAMETERS_MODELMETADATA.values_by_name['SIGN_WITH_ENGINE_ID'].number
return super().Generate(request, **kwargs)
api = Context(stub=CAIMockStub())
width, height = 512, 768
results = api.generate(prompts=["foo bar"], weights=[1.0], width=width, height=height, sign_with_cai=True)

def test_api_generate_cai_signing_unset():
class CAIMockStub(MockStub):
def Generate(self, request: generation.Request, **kwargs) -> Generator[generation.Answer, None, None]:
assert request.image.cai_parameters.model_metadata == \
generation._CAIPARAMETERS_MODELMETADATA.values_by_name['METADATA_UNSPECIFIED'].number
return super().Generate(request, **kwargs)
api = Context(stub=CAIMockStub())
width, height = 512, 768
# sign_with_cai should default to false.
results = api.generate(prompts=["foo bar"], weights=[1.0], width=width, height=height)


def test_api_inpaint():
api = Context(stub=MockStub())
width, height = 512, 768
Expand Down

0 comments on commit 1dd9d49

Please sign in to comment.