Skip to content

Commit

Permalink
Update API usage to only reference C2PA, not CAI
Browse files Browse the repository at this point in the history
  • Loading branch information
Richard Kennedy committed Jun 30, 2023
1 parent 03f2934 commit 3560d57
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 33 deletions.
20 changes: 10 additions & 10 deletions src/stability_sdk/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def generate(
guidance_strength: float = 0.0,
preset: Optional[str] = None,
return_request: bool = False,
cai_add_default_manifest: bool = False,
c2pa_add_default_manifest: bool = False,
) -> Dict[int, List[Any]]:
"""
Generate an image from a set of weighted prompts.
Expand Down Expand Up @@ -165,7 +165,7 @@ def generate(
start_schedule = 1.0 - init_strength
image_params = self._build_image_params(width, height, sampler, steps, seed, samples, cfg_scale,
start_schedule, init_noise_scale, masked_area_init,
guidance_preset, guidance_cuts, guidance_strength, cai_add_default_manifest)
guidance_preset, guidance_cuts, guidance_strength, c2pa_add_default_manifest)

extras = Struct()
if preset and preset.lower() != 'none':
Expand Down Expand Up @@ -236,7 +236,7 @@ def inpaint(
start_schedule = 1.0-init_strength
image_params = self._build_image_params(width, height, sampler, steps, seed, samples, cfg_scale,
start_schedule, init_noise_scale, masked_area_init,
guidance_preset, guidance_cuts, guidance_strength, cai_add_default_manifest=False)
guidance_preset, guidance_cuts, guidance_strength, c2pa_add_default_manifest=False)

extras = Struct()
if preset and preset.lower() != 'none':
Expand Down Expand Up @@ -539,7 +539,7 @@ def _adjust_request_for_retry(self, request: generation.Request, attempt: int):

def _build_image_params(self, width, height, sampler, steps, seed, samples, cfg_scale,
schedule_start, init_noise_scale, masked_area_init,
guidance_preset, guidance_cuts, guidance_strength, cai_add_default_manifest):
guidance_preset, guidance_cuts, guidance_strength, c2pa_add_default_manifest):

if not seed:
seed = [random.randrange(0, 4294967295)]
Expand Down Expand Up @@ -569,11 +569,11 @@ def _build_image_params(self, width, height, sampler, steps, seed, samples, cfg_
)
]
)
# empty CAI Parameters will result in images not being signed by the CAI server
caip = generation.CAIParameters()
if cai_add_default_manifest:
caip = generation.CAIParameters(
model_metadata=generation._CAIPARAMETERS_MODELMETADATA.values_by_name[
# empty C2PA Parameters will result in images not being signed by the C2PA server
c2pa_params = generation.C2PAParameters()
if c2pa_add_default_manifest:
c2pa_params = generation.C2PAParameters(
model_metadata=generation._C2PAPARAMETERS_MODELMETADATA.values_by_name[
'MODEL_METADATA_SIGN_WITH_ENGINE_ID'].number)

return generation.ImageParameters(
Expand All @@ -585,7 +585,7 @@ def _build_image_params(self, width, height, sampler, steps, seed, samples, cfg_
samples=samples,
masked_area_init=masked_area_init,
parameters=[generation.StepParameter(**step_parameters)],
cai_parameters=caip
c2pa_parameters=c2pa_params
)

def _process_response(self, response) -> Dict[int, List[Any]]:
Expand Down
20 changes: 10 additions & 10 deletions src/stability_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def generate(
guidance_strength: Optional[float] = None,
guidance_prompt: Union[str, generation.Prompt] = None,
guidance_models: List[str] = None,
cai_add_default_manifest: bool = False,
c2pa_add_default_manifest: bool = False,
) -> Generator[generation.Answer, None, None]:
"""
Generate images from a prompt.
Expand All @@ -195,7 +195,7 @@ def generate(
:param guidance_strength: Strength of the guidance. We recommend values in range [0.0,1.0]. A good default is 0.25
:param guidance_prompt: Prompt to use for guidance, defaults to `prompt` argument (above) if not specified.
:param guidance_models: Models to use for guidance.
:param cai_add_default_manifest: Add default C2PA manifest or not.
:param c2pa_add_default_manifest: Add default C2PA manifest or not.
:return: Generator of Answer objects.
"""
if (prompt is None) and (init_image is None):
Expand Down Expand Up @@ -280,11 +280,11 @@ def generate(
if sampler:
transform=generation.TransformType(diffusion=sampler)

# empty CAI Parameters will result in images not being signed by the CAI server
caip = generation.CAIParameters()
if cai_add_default_manifest:
caip = generation.CAIParameters(
model_metadata=generation._CAIPARAMETERS_MODELMETADATA.values_by_name[
# empty C2PA Parameters will result in images not being signed by the C2PA server
c2pa_params = generation.C2PAParameters()
if c2pa_add_default_manifest:
c2pa_params = generation.C2PAParameters(
model_metadata=generation._C2PAPARAMETERS_MODELMETADATA.values_by_name[
'MODEL_METADATA_SIGN_WITH_ENGINE_ID'].number)

image_parameters=generation.ImageParameters(
Expand All @@ -295,7 +295,7 @@ def generate(
steps=steps,
samples=samples,
parameters=[generation.StepParameter(**step_parameters)],
cai_parameters=caip,
c2pa_parameters=c2pa_params,
)

return self.emit_request(prompt=prompts, image_parameters=image_parameters)
Expand Down Expand Up @@ -509,7 +509,7 @@ def process_cli(logger: logging.Logger = None,
"--width", "-W", type=int, default=512, help="[512] width of image"
)
parser_generate.add_argument(
"--cai_add_default_manifest", type=bool, default=False, help="Attatch a signed manifest to artifacts using C2PA. The default manifest will contain engine id and publisher name (Stability AI)"
"--c2pa_add_default_manifest", type=bool, action='store_true', default=False, help="Attatch a signed manifest to artifacts using C2PA. The default manifest will contain engine id and publisher name (Stability AI)"
)
parser_generate.add_argument(
"--start_schedule",
Expand Down Expand Up @@ -644,7 +644,7 @@ def process_cli(logger: logging.Logger = None,
"samples": args.num_samples,
"init_image": args.init_image,
"mask_image": args.mask_image,
"cai_add_default_manifest": args.cai_add_default_manifest,
"c2pa_add_default_manifest": args.c2pa_add_default_manifest,
}

if args.sampler:
Expand Down
2 changes: 1 addition & 1 deletion src/stability_sdk/interfaces
24 changes: 12 additions & 12 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,25 +99,25 @@ def test_api_generate():
assert isinstance(image, Image.Image)
assert image.size == (width, height)

def test_api_generate_cai_signing_set():
class CAIMockStub(MockStub):
def test_api_generate_c2pa_signing_set():
class C2PAMockStub(MockStub):
def Generate(self, request: generation.Request, **kwargs) -> Generator[generation.Answer, None, None]:
assert request.image.cai_parameters.model_metadata == \
generation._CAIPARAMETERS_MODELMETADATA.values_by_name['MODEL_METADATA_SIGN_WITH_ENGINE_ID'].number
assert request.image.c2pa_parameters.model_metadata == \
generation._C2PAPARAMETERS_MODELMETADATA.values_by_name['MODEL_METADATA_SIGN_WITH_ENGINE_ID'].number
return super().Generate(request, **kwargs)
api = Context(stub=CAIMockStub())
api = Context(stub=C2PAMockStub())
width, height = 512, 768
results = api.generate(prompts=["foo bar"], weights=[1.0], width=width, height=height, cai_add_default_manifest=True)
results = api.generate(prompts=["foo bar"], weights=[1.0], width=width, height=height, c2pa_add_default_manifest=True)

def test_api_generate_cai_signing_unset():
class CAIMockStub(MockStub):
def test_api_generate_c2pa_signing_unset():
class C2PAMockStub(MockStub):
def Generate(self, request: generation.Request, **kwargs) -> Generator[generation.Answer, None, None]:
assert request.image.cai_parameters.model_metadata == \
generation._CAIPARAMETERS_MODELMETADATA.values_by_name['MODEL_METADATA_UNSPECIFIED'].number
assert request.image.c2pa_parameters.model_metadata == \
generation._C2PAPARAMETERS_MODELMETADATA.values_by_name['MODEL_METADATA_UNSPECIFIED'].number
return super().Generate(request, **kwargs)
api = Context(stub=CAIMockStub())
api = Context(stub=C2PAMockStub())
width, height = 512, 768
# sign_with_cai should default to false.
# sign_with_c2pa should default to false.
results = api.generate(prompts=["foo bar"], weights=[1.0], width=width, height=height)


Expand Down

0 comments on commit 3560d57

Please sign in to comment.