Skip to content

Commit 3a8788c

Browse files
authored
refactor: unify extra argument parsing (leejet#1540)
1 parent 449165c commit 3a8788c

4 files changed

Lines changed: 127 additions & 188 deletions

File tree

src/denoiser.hpp

Lines changed: 23 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -496,84 +496,26 @@ struct LTX2Scheduler : SigmaScheduler {
496496
parse_extra_sample_args(extra_sample_args);
497497
}
498498

499-
static std::string trim(std::string value) {
500-
const char* whitespace = " \t\r\n";
501-
size_t begin = value.find_first_not_of(whitespace);
502-
if (begin == std::string::npos) {
503-
return "";
504-
}
505-
size_t end = value.find_last_not_of(whitespace);
506-
return value.substr(begin, end - begin + 1);
507-
}
508-
509499
void parse_extra_sample_args(const char* extra_sample_args) {
510-
if (extra_sample_args == nullptr || extra_sample_args[0] == '\0') {
511-
return;
512-
}
513-
514-
std::string raw(extra_sample_args);
515-
size_t start = 0;
516-
auto parse_arg = [&](const std::string& item) {
517-
std::string token = trim(item);
518-
if (token.empty()) {
519-
return;
520-
}
521-
size_t eq = token.find('=');
522-
if (eq == std::string::npos) {
523-
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
524-
return;
525-
}
526-
527-
std::string key = trim(token.substr(0, eq));
528-
std::string value = trim(token.substr(eq + 1));
529-
auto parse_float = [&](float* out) -> bool {
530-
try {
531-
size_t consumed = 0;
532-
float parsed = std::stof(value, &consumed);
533-
if (!trim(value.substr(consumed)).empty()) {
534-
return false;
535-
}
536-
*out = parsed;
537-
return true;
538-
} catch (const std::exception&) {
539-
return false;
500+
for (const auto& [key, value] : parse_key_value_args(extra_sample_args, "ltx2 scheduler arg")) {
501+
if (key == "max_shift") {
502+
if (!parse_strict_float(value, max_shift)) {
503+
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str());
540504
}
541-
};
542-
try {
543-
if (key == "max_shift") {
544-
if (!parse_float(&max_shift)) {
545-
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
546-
}
547-
} else if (key == "base_shift") {
548-
if (!parse_float(&base_shift)) {
549-
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
550-
}
551-
} else if (key == "terminal") {
552-
if (!parse_float(&terminal)) {
553-
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
554-
}
555-
} else if (key == "stretch") {
556-
std::string v = value;
557-
std::transform(v.begin(), v.end(), v.begin(), [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
558-
if (v == "1" || v == "true" || v == "yes" || v == "on") {
559-
stretch = true;
560-
} else if (v == "0" || v == "false" || v == "no" || v == "off") {
561-
stretch = false;
562-
} else {
563-
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
564-
}
565-
} else {
566-
LOG_WARN("ignoring unknown ltx2 scheduler arg '%s'", key.c_str());
505+
} else if (key == "base_shift") {
506+
if (!parse_strict_float(value, base_shift)) {
507+
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str());
567508
}
568-
} catch (const std::exception&) {
569-
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
570-
}
571-
};
572-
573-
for (size_t pos = 0; pos <= raw.size(); ++pos) {
574-
if (pos == raw.size() || raw[pos] == ',' || raw[pos] == ';') {
575-
parse_arg(raw.substr(start, pos - start));
576-
start = pos + 1;
509+
} else if (key == "terminal") {
510+
if (!parse_strict_float(value, terminal)) {
511+
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str());
512+
}
513+
} else if (key == "stretch") {
514+
if (!parse_strict_bool(value, stretch)) {
515+
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str());
516+
}
517+
} else {
518+
LOG_WARN("ignoring unknown ltx2 scheduler arg '%s'", key.c_str());
577519
}
578520
}
579521
}
@@ -1276,7 +1218,7 @@ static sd::Tensor<float> sample_dpmpp_2m_v2(denoise_cb_t model,
12761218
return x;
12771219
}
12781220

1279-
using SamplerExtraArgs = std::vector<std::pair<std::string, std::string>>;
1221+
using SamplerExtraArgs = KeyValueArgs;
12801222

12811223
static sd::Tensor<float> sample_lcm(denoise_cb_t model,
12821224
sd::Tensor<float> x,
@@ -1296,15 +1238,8 @@ static sd::Tensor<float> sample_lcm(denoise_cb_t model,
12961238

12971239
for (const auto& [key, value] : extra_sample_args) {
12981240
float parsed = 0.0f;
1299-
try {
1300-
size_t consumed = 0;
1301-
parsed = std::stof(value, &consumed);
1302-
if (trim(value.substr(consumed)).size() != 0) {
1303-
LOG_WARN("ignoring invalid lcm extra sample arg '%s'", key.c_str());
1304-
continue;
1305-
}
1306-
} catch (const std::exception&) {
1307-
LOG_WARN("ignoring invalid lcm extra sample arg '%s=%s'", key.c_str());
1241+
if (!parse_strict_float(value, parsed)) {
1242+
LOG_WARN("ignoring invalid lcm extra sample arg '%s=%s'", key.c_str(), value.c_str());
13081243
continue;
13091244
}
13101245
if (key == "noise_clip_std") {
@@ -1861,15 +1796,8 @@ static sd::Tensor<float> sample_gradient_estimation(denoise_cb_t model,
18611796

18621797
for (const auto& [key, value] : extra_sample_args) {
18631798
float parsed = 0.0f;
1864-
try {
1865-
size_t consumed = 0;
1866-
parsed = std::stof(value, &consumed);
1867-
if (trim(value.substr(consumed)).size() != 0) {
1868-
LOG_WARN("ignoring invalid euler_ge extra sample arg '%s'", key.c_str());
1869-
continue;
1870-
}
1871-
} catch (const std::exception&) {
1872-
LOG_WARN("ignoring invalid euler_ge extra sample arg '%s'", key.c_str());
1799+
if (!parse_strict_float(value, parsed)) {
1800+
LOG_WARN("ignoring invalid euler_ge extra sample arg '%s=%s'", key.c_str(), value.c_str());
18731801
continue;
18741802
}
18751803
if (key == "gamma") {
@@ -1916,46 +1844,6 @@ static sd::Tensor<float> sample_gradient_estimation(denoise_cb_t model,
19161844
return x;
19171845
}
19181846

1919-
static SamplerExtraArgs parse_sampler_args(const char* extra_sample_args) {
1920-
SamplerExtraArgs pairs;
1921-
1922-
if (extra_sample_args == nullptr || extra_sample_args[0] == '\0') {
1923-
return pairs;
1924-
}
1925-
1926-
auto trim = [](std::string value) -> std::string {
1927-
const char* whitespace = " \t\r\n";
1928-
size_t begin = value.find_first_not_of(whitespace);
1929-
if (begin == std::string::npos) {
1930-
return "";
1931-
}
1932-
size_t end = value.find_last_not_of(whitespace);
1933-
return value.substr(begin, end - begin + 1);
1934-
};
1935-
1936-
std::string raw(extra_sample_args);
1937-
size_t start = 0;
1938-
1939-
for (size_t pos = 0; pos <= raw.size(); ++pos) {
1940-
if (pos == raw.size() || raw[pos] == ',' || raw[pos] == ';') {
1941-
std::string item = raw.substr(start, pos - start);
1942-
std::string token = trim(item);
1943-
1944-
if (!token.empty()) {
1945-
size_t eq = token.find('=');
1946-
if (eq != std::string::npos) {
1947-
std::string key = trim(token.substr(0, eq));
1948-
std::string value = trim(token.substr(eq + 1));
1949-
pairs.emplace_back(std::move(key), std::move(value));
1950-
}
1951-
}
1952-
start = pos + 1;
1953-
}
1954-
}
1955-
1956-
return pairs;
1957-
}
1958-
19591847
// k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t
19601848
static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
19611849
denoise_cb_t model,
@@ -1965,7 +1853,7 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
19651853
float eta,
19661854
bool is_flow_denoiser,
19671855
const char* extra_sample_args) {
1968-
SamplerExtraArgs extra_args = parse_sampler_args(extra_sample_args);
1856+
SamplerExtraArgs extra_args = parse_key_value_args(extra_sample_args, "extra sample arg");
19691857
switch (method) {
19701858
case EULER_A_SAMPLE_METHOD:
19711859
return sample_euler_ancestral(model, std::move(x), sigmas, rng, is_flow_denoiser, eta);

src/ltx_vae.hpp

Lines changed: 10 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1251,65 +1251,22 @@ struct LTXVideoVAE : public VAE {
12511251
temporal_tiling_enabled = enabled;
12521252
}
12531253

1254-
static std::string trim_tiling_arg(std::string value) {
1255-
const char* whitespace = " \t\r\n";
1256-
size_t begin = value.find_first_not_of(whitespace);
1257-
if (begin == std::string::npos) {
1258-
return "";
1259-
}
1260-
size_t end = value.find_last_not_of(whitespace);
1261-
return value.substr(begin, end - begin + 1);
1262-
}
1263-
1264-
static bool parse_tiling_int(const std::string& value, int& parsed) {
1265-
try {
1266-
size_t consumed = 0;
1267-
parsed = std::stoi(value, &consumed);
1268-
return trim_tiling_arg(value.substr(consumed)).empty();
1269-
} catch (...) {
1270-
return false;
1271-
}
1272-
}
1273-
12741254
void set_tiling_params(const sd_tiling_params_t& params) override {
12751255
temporal_tiling_enabled = params.temporal_tiling;
12761256
temporal_tile_frames = DEFAULT_TEMPORAL_TILE_FRAMES;
12771257
temporal_tile_overlap = DEFAULT_TEMPORAL_TILE_OVERLAP;
12781258

1279-
const char* extra_tiling_args = params.extra_tiling_args;
1280-
if (extra_tiling_args == nullptr || extra_tiling_args[0] == '\0') {
1281-
return;
1282-
}
1283-
1284-
std::string raw(extra_tiling_args);
1285-
size_t start = 0;
1286-
for (size_t pos = 0; pos <= raw.size(); ++pos) {
1287-
if (pos != raw.size() && raw[pos] != ',' && raw[pos] != ';') {
1288-
continue;
1289-
}
1290-
1291-
std::string token = trim_tiling_arg(raw.substr(start, pos - start));
1292-
if (!token.empty()) {
1293-
size_t eq = token.find('=');
1294-
if (eq == std::string::npos) {
1295-
LOG_WARN("ignoring malformed LTX VAE extra tiling arg '%s'", token.c_str());
1296-
} else {
1297-
std::string key = trim_tiling_arg(token.substr(0, eq));
1298-
std::string value = trim_tiling_arg(token.substr(eq + 1));
1299-
int parsed = 0;
1300-
if (!parse_tiling_int(value, parsed)) {
1301-
LOG_WARN("ignoring invalid LTX VAE extra tiling arg '%s=%s'", key.c_str(), value.c_str());
1302-
} else if (key == "temporal_tile_frames") {
1303-
temporal_tile_frames = std::max(1, parsed);
1304-
} else if (key == "temporal_tile_overlap") {
1305-
temporal_tile_overlap = std::max(0, parsed);
1306-
} else {
1307-
LOG_WARN("ignoring unknown LTX VAE extra tiling arg '%s'", key.c_str());
1308-
}
1309-
}
1259+
for (const auto& [key, value] : parse_key_value_args(params.extra_tiling_args, "LTX VAE extra tiling arg")) {
1260+
int parsed = 0;
1261+
if (!parse_strict_int(value, parsed)) {
1262+
LOG_WARN("ignoring invalid LTX VAE extra tiling arg '%s=%s'", key.c_str(), value.c_str());
1263+
} else if (key == "temporal_tile_frames") {
1264+
temporal_tile_frames = std::max(1, parsed);
1265+
} else if (key == "temporal_tile_overlap") {
1266+
temporal_tile_overlap = std::max(0, parsed);
1267+
} else {
1268+
LOG_WARN("ignoring unknown LTX VAE extra tiling arg '%s'", key.c_str());
13101269
}
1311-
1312-
start = pos + 1;
13131270
}
13141271
}
13151272

src/util.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
#include "util.h"
22
#include <algorithm>
3+
#include <cctype>
34
#include <cmath>
45
#include <codecvt>
56
#include <cstdarg>
7+
#include <exception>
68
#include <fstream>
79
#include <locale>
810
#include <regex>
@@ -406,6 +408,88 @@ std::vector<std::string> split_string(const std::string& str, char delimiter) {
406408
return result;
407409
}
408410

411+
KeyValueArgs parse_key_value_args(const char* args, const char* context) {
412+
KeyValueArgs pairs;
413+
414+
if (args == nullptr || args[0] == '\0') {
415+
return pairs;
416+
}
417+
418+
std::string raw(args);
419+
size_t start = 0;
420+
for (size_t pos = 0; pos <= raw.size(); ++pos) {
421+
if (pos != raw.size() && raw[pos] != ',' && raw[pos] != ';') {
422+
continue;
423+
}
424+
425+
std::string token = trim(raw.substr(start, pos - start));
426+
if (!token.empty()) {
427+
size_t eq = token.find('=');
428+
if (eq == std::string::npos) {
429+
const char* log_context = context ? context : "key=value arg";
430+
LOG_WARN("ignoring malformed %s '%s'", log_context, token.c_str());
431+
} else {
432+
std::string key = trim(token.substr(0, eq));
433+
std::string value = trim(token.substr(eq + 1));
434+
pairs.emplace_back(std::move(key), std::move(value));
435+
}
436+
}
437+
438+
start = pos + 1;
439+
}
440+
441+
return pairs;
442+
}
443+
444+
KeyValueArgs parse_key_value_args(const std::string& args, const char* context) {
445+
return parse_key_value_args(args.c_str(), context);
446+
}
447+
448+
bool parse_strict_float(const std::string& text, float& value) {
449+
try {
450+
size_t consumed = 0;
451+
float parsed = std::stof(text, &consumed);
452+
if (!trim(text.substr(consumed)).empty()) {
453+
return false;
454+
}
455+
value = parsed;
456+
return true;
457+
} catch (const std::exception&) {
458+
return false;
459+
}
460+
}
461+
462+
bool parse_strict_int(const std::string& text, int& value) {
463+
try {
464+
size_t consumed = 0;
465+
int parsed = std::stoi(text, &consumed);
466+
if (!trim(text.substr(consumed)).empty()) {
467+
return false;
468+
}
469+
value = parsed;
470+
return true;
471+
} catch (const std::exception&) {
472+
return false;
473+
}
474+
}
475+
476+
bool parse_strict_bool(const std::string& text, bool& value) {
477+
std::string lowered = trim(text);
478+
std::transform(lowered.begin(), lowered.end(), lowered.begin(), [](unsigned char c) {
479+
return static_cast<char>(std::tolower(c));
480+
});
481+
482+
if (lowered == "1" || lowered == "true" || lowered == "yes" || lowered == "on") {
483+
value = true;
484+
return true;
485+
}
486+
if (lowered == "0" || lowered == "false" || lowered == "no" || lowered == "off") {
487+
value = false;
488+
return true;
489+
}
490+
return false;
491+
}
492+
409493
static std::string build_progress_bar(int step, int steps) {
410494
std::string progress = " |";
411495
int max_progress = 50;

0 commit comments

Comments
 (0)