@@ -1821,16 +1821,23 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
18211821 } else if (sd_ctx->sd ->version == VERSION_FLEX_2 ) {
18221822 mask_channels = 1 + init_latent->ne [2 ];
18231823 }
1824- ggml_tensor* masked_img = ggml_new_tensor_4d (work_ctx, GGML_TYPE_F32 , width, height, 3 , 1 );
1825- // Restore init_img (encode_first_stage has side effects) TODO: remove the side effects?
1826- sd_image_to_tensor (init_image.data , init_img);
1827- sd_apply_mask (init_img, mask_img, masked_img);
18281824 ggml_tensor* masked_latent = NULL ;
1829- if (!sd_ctx->sd ->use_tiny_autoencoder ) {
1830- ggml_tensor* moments = sd_ctx->sd ->encode_first_stage (work_ctx, masked_img);
1831- masked_latent = sd_ctx->sd ->get_first_stage_encoding (work_ctx, moments);
1825+ if (sd_ctx->sd ->version != VERSION_FLEX_2 ) {
1826+ // most inpaint models mask before vae
1827+ ggml_tensor* masked_img = ggml_new_tensor_4d (work_ctx, GGML_TYPE_F32 , width, height, 3 , 1 );
1828+ // Restore init_img (encode_first_stage has side effects) TODO: remove the side effects?
1829+ sd_image_to_tensor (init_image.data , init_img);
1830+ sd_apply_mask (init_img, mask_img, masked_img);
1831+ if (!sd_ctx->sd ->use_tiny_autoencoder ) {
1832+ ggml_tensor* moments = sd_ctx->sd ->encode_first_stage (work_ctx, masked_img);
1833+ masked_latent = sd_ctx->sd ->get_first_stage_encoding (work_ctx, moments);
1834+ } else {
1835+ masked_latent = sd_ctx->sd ->encode_first_stage (work_ctx, masked_img);
1836+ }
18321837 } else {
1833- masked_latent = sd_ctx->sd ->encode_first_stage (work_ctx, masked_img);
1838+ // mask after vae
1839+ masked_latent = ggml_new_tensor_4d (work_ctx, GGML_TYPE_F32 , init_latent->ne [0 ], init_latent->ne [1 ], init_latent->ne [2 ], 1 );
1840+ sd_apply_mask (init_latent, mask_img, masked_latent, 0 .);
18341841 }
18351842 concat_latent = ggml_new_tensor_4d (work_ctx, GGML_TYPE_F32 , masked_latent->ne [0 ], masked_latent->ne [1 ], mask_channels + masked_latent->ne [2 ], 1 );
18361843 for (int ix = 0 ; ix < masked_latent->ne [0 ]; ix++) {
0 commit comments