@@ -496,84 +496,26 @@ struct LTX2Scheduler : SigmaScheduler {
496496 parse_extra_sample_args (extra_sample_args);
497497 }
498498
499- static std::string trim (std::string value) {
500- const char * whitespace = " \t\r\n " ;
501- size_t begin = value.find_first_not_of (whitespace);
502- if (begin == std::string::npos) {
503- return " " ;
504- }
505- size_t end = value.find_last_not_of (whitespace);
506- return value.substr (begin, end - begin + 1 );
507- }
508-
509499 void parse_extra_sample_args (const char * extra_sample_args) {
510- if (extra_sample_args == nullptr || extra_sample_args[0 ] == ' \0 ' ) {
511- return ;
512- }
513-
514- std::string raw (extra_sample_args);
515- size_t start = 0 ;
516- auto parse_arg = [&](const std::string& item) {
517- std::string token = trim (item);
518- if (token.empty ()) {
519- return ;
520- }
521- size_t eq = token.find (' =' );
522- if (eq == std::string::npos) {
523- LOG_WARN (" ignoring invalid ltx2 scheduler arg '%s'" , token.c_str ());
524- return ;
525- }
526-
527- std::string key = trim (token.substr (0 , eq));
528- std::string value = trim (token.substr (eq + 1 ));
529- auto parse_float = [&](float * out) -> bool {
530- try {
531- size_t consumed = 0 ;
532- float parsed = std::stof (value, &consumed);
533- if (!trim (value.substr (consumed)).empty ()) {
534- return false ;
535- }
536- *out = parsed;
537- return true ;
538- } catch (const std::exception&) {
539- return false ;
500+ for (const auto & [key, value] : parse_key_value_args (extra_sample_args, " ltx2 scheduler arg" )) {
501+ if (key == " max_shift" ) {
502+ if (!parse_strict_float (value, max_shift)) {
503+ LOG_WARN (" ignoring invalid ltx2 scheduler arg '%s=%s'" , key.c_str (), value.c_str ());
540504 }
541- };
542- try {
543- if (key == " max_shift" ) {
544- if (!parse_float (&max_shift)) {
545- LOG_WARN (" ignoring invalid ltx2 scheduler arg '%s'" , token.c_str ());
546- }
547- } else if (key == " base_shift" ) {
548- if (!parse_float (&base_shift)) {
549- LOG_WARN (" ignoring invalid ltx2 scheduler arg '%s'" , token.c_str ());
550- }
551- } else if (key == " terminal" ) {
552- if (!parse_float (&terminal)) {
553- LOG_WARN (" ignoring invalid ltx2 scheduler arg '%s'" , token.c_str ());
554- }
555- } else if (key == " stretch" ) {
556- std::string v = value;
557- std::transform (v.begin (), v.end (), v.begin (), [](unsigned char c) { return static_cast <char >(std::tolower (c)); });
558- if (v == " 1" || v == " true" || v == " yes" || v == " on" ) {
559- stretch = true ;
560- } else if (v == " 0" || v == " false" || v == " no" || v == " off" ) {
561- stretch = false ;
562- } else {
563- LOG_WARN (" ignoring invalid ltx2 scheduler arg '%s'" , token.c_str ());
564- }
565- } else {
566- LOG_WARN (" ignoring unknown ltx2 scheduler arg '%s'" , key.c_str ());
505+ } else if (key == " base_shift" ) {
506+ if (!parse_strict_float (value, base_shift)) {
507+ LOG_WARN (" ignoring invalid ltx2 scheduler arg '%s=%s'" , key.c_str (), value.c_str ());
567508 }
568- } catch (const std::exception&) {
569- LOG_WARN (" ignoring invalid ltx2 scheduler arg '%s'" , token.c_str ());
570- }
571- };
572-
573- for (size_t pos = 0 ; pos <= raw.size (); ++pos) {
574- if (pos == raw.size () || raw[pos] == ' ,' || raw[pos] == ' ;' ) {
575- parse_arg (raw.substr (start, pos - start));
576- start = pos + 1 ;
509+ } else if (key == " terminal" ) {
510+ if (!parse_strict_float (value, terminal)) {
511+ LOG_WARN (" ignoring invalid ltx2 scheduler arg '%s=%s'" , key.c_str (), value.c_str ());
512+ }
513+ } else if (key == " stretch" ) {
514+ if (!parse_strict_bool (value, stretch)) {
515+ LOG_WARN (" ignoring invalid ltx2 scheduler arg '%s=%s'" , key.c_str (), value.c_str ());
516+ }
517+ } else {
518+ LOG_WARN (" ignoring unknown ltx2 scheduler arg '%s'" , key.c_str ());
577519 }
578520 }
579521 }
@@ -1276,7 +1218,7 @@ static sd::Tensor<float> sample_dpmpp_2m_v2(denoise_cb_t model,
12761218 return x;
12771219}
12781220
1279- using SamplerExtraArgs = std::vector<std::pair<std::string, std::string>> ;
1221+ using SamplerExtraArgs = KeyValueArgs ;
12801222
12811223static sd::Tensor<float > sample_lcm (denoise_cb_t model,
12821224 sd::Tensor<float > x,
@@ -1296,15 +1238,8 @@ static sd::Tensor<float> sample_lcm(denoise_cb_t model,
12961238
12971239 for (const auto & [key, value] : extra_sample_args) {
12981240 float parsed = 0 .0f ;
1299- try {
1300- size_t consumed = 0 ;
1301- parsed = std::stof (value, &consumed);
1302- if (trim (value.substr (consumed)).size () != 0 ) {
1303- LOG_WARN (" ignoring invalid lcm extra sample arg '%s'" , key.c_str ());
1304- continue ;
1305- }
1306- } catch (const std::exception&) {
1307- LOG_WARN (" ignoring invalid lcm extra sample arg '%s=%s'" , key.c_str ());
1241+ if (!parse_strict_float (value, parsed)) {
1242+ LOG_WARN (" ignoring invalid lcm extra sample arg '%s=%s'" , key.c_str (), value.c_str ());
13081243 continue ;
13091244 }
13101245 if (key == " noise_clip_std" ) {
@@ -1861,15 +1796,8 @@ static sd::Tensor<float> sample_gradient_estimation(denoise_cb_t model,
18611796
18621797 for (const auto & [key, value] : extra_sample_args) {
18631798 float parsed = 0 .0f ;
1864- try {
1865- size_t consumed = 0 ;
1866- parsed = std::stof (value, &consumed);
1867- if (trim (value.substr (consumed)).size () != 0 ) {
1868- LOG_WARN (" ignoring invalid euler_ge extra sample arg '%s'" , key.c_str ());
1869- continue ;
1870- }
1871- } catch (const std::exception&) {
1872- LOG_WARN (" ignoring invalid euler_ge extra sample arg '%s'" , key.c_str ());
1799+ if (!parse_strict_float (value, parsed)) {
1800+ LOG_WARN (" ignoring invalid euler_ge extra sample arg '%s=%s'" , key.c_str (), value.c_str ());
18731801 continue ;
18741802 }
18751803 if (key == " gamma" ) {
@@ -1916,46 +1844,6 @@ static sd::Tensor<float> sample_gradient_estimation(denoise_cb_t model,
19161844 return x;
19171845}
19181846
1919- static SamplerExtraArgs parse_sampler_args (const char * extra_sample_args) {
1920- SamplerExtraArgs pairs;
1921-
1922- if (extra_sample_args == nullptr || extra_sample_args[0 ] == ' \0 ' ) {
1923- return pairs;
1924- }
1925-
1926- auto trim = [](std::string value) -> std::string {
1927- const char * whitespace = " \t\r\n " ;
1928- size_t begin = value.find_first_not_of (whitespace);
1929- if (begin == std::string::npos) {
1930- return " " ;
1931- }
1932- size_t end = value.find_last_not_of (whitespace);
1933- return value.substr (begin, end - begin + 1 );
1934- };
1935-
1936- std::string raw (extra_sample_args);
1937- size_t start = 0 ;
1938-
1939- for (size_t pos = 0 ; pos <= raw.size (); ++pos) {
1940- if (pos == raw.size () || raw[pos] == ' ,' || raw[pos] == ' ;' ) {
1941- std::string item = raw.substr (start, pos - start);
1942- std::string token = trim (item);
1943-
1944- if (!token.empty ()) {
1945- size_t eq = token.find (' =' );
1946- if (eq != std::string::npos) {
1947- std::string key = trim (token.substr (0 , eq));
1948- std::string value = trim (token.substr (eq + 1 ));
1949- pairs.emplace_back (std::move (key), std::move (value));
1950- }
1951- }
1952- start = pos + 1 ;
1953- }
1954- }
1955-
1956- return pairs;
1957- }
1958-
19591847// k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t
19601848static sd::Tensor<float > sample_k_diffusion (sample_method_t method,
19611849 denoise_cb_t model,
@@ -1965,7 +1853,7 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
19651853 float eta,
19661854 bool is_flow_denoiser,
19671855 const char * extra_sample_args) {
1968- SamplerExtraArgs extra_args = parse_sampler_args (extra_sample_args);
1856+ SamplerExtraArgs extra_args = parse_key_value_args (extra_sample_args, " extra sample arg " );
19691857 switch (method) {
19701858 case EULER_A_SAMPLE_METHOD :
19711859 return sample_euler_ancestral (model, std::move (x), sigmas, rng, is_flow_denoiser, eta);
0 commit comments