From c77e525243dc233b88edaefabbfc5f19facf4044 Mon Sep 17 00:00:00 2001 From: Ryan Hill <38674843+RyanUnderhill@users.noreply.github.com> Date: Tue, 17 Dec 2024 19:59:56 -0800 Subject: [PATCH] Better JSON type mismatch errors (#1147) This changes the JSON parsing to use a std::variant so there is just a single OnValue handler vs OnString/OnNumber/OnBool/OnNull. Previously a mismatched type would say `JSON Error: Unknown value: name at line 3 index 19` or it would say `JSON Error: Unknown value: name` if the name was known but the type of its value was wrong (example: https://github.com/microsoft/onnxruntime-genai/issues/1146). Now it'll give a much better error message, showing first the full path of the field being parsed, and then saying exactly how the types mismatch: `JSON Error: model:type - Expected a number but saw a string at line 3 index 19` --- src/config.cpp | 257 +++++++++++++++++++++---------------------------- src/json.cpp | 29 +++--- src/json.h | 24 +++-- 3 files changed, 143 insertions(+), 167 deletions(-) diff --git a/src/config.cpp b/src/config.cpp index 0c6de4d69..459da0d68 100644 --- a/src/config.cpp +++ b/src/config.cpp @@ -22,8 +22,8 @@ ONNXTensorElementDataType TranslateTensorType(std::string_view value) { struct ProviderOptions_Element : JSON::Element { explicit ProviderOptions_Element(Config::ProviderOptions& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { - v_.options.emplace_back(name, value); + void OnValue(std::string_view name, JSON::Value value) override { + v_.options.emplace_back(name, JSON::Get(value)); } private: @@ -65,45 +65,35 @@ struct ProviderOptionsArray_Element : JSON::Element { struct SessionOptions_Element : JSON::Element { explicit SessionOptions_Element(Config::SessionOptions& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "log_id") - v_.log_id = value; + v_.log_id 
= JSON::Get(value); else if (name == "enable_profiling") - v_.enable_profiling = value; + v_.enable_profiling = JSON::Get(value); else if (name == "ep_context_embed_mode") - v_.ep_context_embed_mode = value; + v_.ep_context_embed_mode = JSON::Get(value); else if (name == "ep_context_file_path") - v_.ep_context_file_path = value; - else - throw JSON::unknown_value_error{}; - } - - void OnNumber(std::string_view name, double value) override { - if (name == "intra_op_num_threads") - v_.intra_op_num_threads = static_cast(value); + v_.ep_context_file_path = JSON::Get(value); + else if (name == "intra_op_num_threads") + v_.intra_op_num_threads = static_cast(JSON::Get(value)); else if (name == "inter_op_num_threads") - v_.inter_op_num_threads = static_cast(value); + v_.inter_op_num_threads = static_cast(JSON::Get(value)); else if (name == "log_severity_level") - v_.log_severity_level = static_cast(value); - else - throw JSON::unknown_value_error{}; - } - - void OnBool(std::string_view name, bool value) override { - if (name == "enable_cpu_mem_arena") - v_.enable_cpu_mem_arena = value; + v_.log_severity_level = static_cast(JSON::Get(value)); + else if (name == "enable_cpu_mem_arena") + v_.enable_cpu_mem_arena = JSON::Get(value); else if (name == "enable_mem_pattern") - v_.enable_mem_pattern = value; + v_.enable_mem_pattern = JSON::Get(value); else if (name == "disable_cpu_ep_fallback") - v_.disable_cpu_ep_fallback = value; + v_.disable_cpu_ep_fallback = JSON::Get(value); else if (name == "disable_quant_qdq") - v_.disable_quant_qdq = value; + v_.disable_quant_qdq = JSON::Get(value); else if (name == "enable_quant_qdq_cleanup") - v_.enable_quant_qdq_cleanup = value; + v_.enable_quant_qdq_cleanup = JSON::Get(value); else if (name == "ep_context_enable") - v_.ep_context_enable = value; + v_.ep_context_enable = JSON::Get(value); else if (name == "use_env_allocators") - v_.use_env_allocators = value; + v_.use_env_allocators = JSON::Get(value); else throw 
JSON::unknown_value_error{}; } @@ -122,9 +112,9 @@ struct SessionOptions_Element : JSON::Element { struct EncoderDecoderInit_Element : JSON::Element { explicit EncoderDecoderInit_Element(Config::Model::EncoderDecoderInit& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "filename") { - v_.filename = value; + v_.filename = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -136,29 +126,29 @@ struct EncoderDecoderInit_Element : JSON::Element { struct Inputs_Element : JSON::Element { explicit Inputs_Element(Config::Model::Decoder::Inputs& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "input_ids") { - v_.input_ids = value; + v_.input_ids = JSON::Get(value); } else if (name == "inputs_embeds") { - v_.embeddings = value; + v_.embeddings = JSON::Get(value); } else if (name == "position_ids") { - v_.position_ids = value; + v_.position_ids = JSON::Get(value); } else if (name == "attention_mask") { - v_.attention_mask = value; + v_.attention_mask = JSON::Get(value); } else if (name == "past_key_names") { - v_.past_key_names = value; + v_.past_key_names = JSON::Get(value); } else if (name == "past_value_names") { - v_.past_value_names = value; + v_.past_value_names = JSON::Get(value); } else if (name == "past_names") { - v_.past_names = value; + v_.past_names = JSON::Get(value); } else if (name == "cross_past_key_names") { - v_.cross_past_key_names = value; + v_.cross_past_key_names = JSON::Get(value); } else if (name == "cross_past_value_names") { - v_.cross_past_value_names = value; + v_.cross_past_value_names = JSON::Get(value); } else if (name == "current_sequence_length") { - v_.current_sequence_length = value; + v_.current_sequence_length = JSON::Get(value); } else if (name == "past_sequence_length") { - 
v_.past_sequence_length = value; + v_.past_sequence_length = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -170,19 +160,19 @@ struct Inputs_Element : JSON::Element { struct Outputs_Element : JSON::Element { explicit Outputs_Element(Config::Model::Decoder::Outputs& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "logits") { - v_.logits = value; + v_.logits = JSON::Get(value); } else if (name == "present_key_names") { - v_.present_key_names = value; + v_.present_key_names = JSON::Get(value); } else if (name == "present_value_names") { - v_.present_value_names = value; + v_.present_value_names = JSON::Get(value); } else if (name == "present_names") { - v_.present_names = value; + v_.present_names = JSON::Get(value); } else if (name == "cross_present_key_names") { - v_.cross_present_key_names = value; + v_.cross_present_key_names = JSON::Get(value); } else if (name == "cross_present_value_names") { - v_.cross_present_value_names = value; + v_.cross_present_value_names = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -194,8 +184,8 @@ struct Outputs_Element : JSON::Element { struct StringArray_Element : JSON::Element { explicit StringArray_Element(std::vector& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { - v_.push_back(std::string(value)); + void OnValue(std::string_view name, JSON::Value value) override { + v_.push_back(std::string{JSON::Get(value)}); } private: @@ -205,8 +195,8 @@ struct StringArray_Element : JSON::Element { struct StringStringMap_Element : JSON::Element { explicit StringStringMap_Element(std::unordered_map& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { - v_[std::string(name)] = std::string(value); + void OnValue(std::string_view name, JSON::Value value) override { + v_[std::string(name)] = 
std::string(JSON::Get(value)); } private: @@ -216,18 +206,13 @@ struct StringStringMap_Element : JSON::Element { struct PipelineModel_Element : JSON::Element { explicit PipelineModel_Element(Config::Model::Decoder::PipelineModel& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "filename") { - v_.filename = value; - } else - throw JSON::unknown_value_error{}; - } - - void OnBool(std::string_view name, bool value) override { - if (name == "run_on_prompt") { - v_.run_on_prompt = value; + v_.filename = JSON::Get(value); + } else if (name == "run_on_prompt") { + v_.run_on_prompt = JSON::Get(value); } else if (name == "run_on_token_gen") { - v_.run_on_token_gen = value; + v_.run_on_token_gen = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -289,24 +274,19 @@ struct Pipeline_Element : JSON::Element { struct Decoder_Element : JSON::Element { explicit Decoder_Element(Config::Model::Decoder& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "filename") { - v_.filename = value; - } else - throw JSON::unknown_value_error{}; - } - - void OnNumber(std::string_view name, double value) override { - if (name == "hidden_size") { - v_.hidden_size = static_cast(value); + v_.filename = JSON::Get(value); + } else if (name == "hidden_size") { + v_.hidden_size = static_cast(JSON::Get(value)); } else if (name == "num_attention_heads") { - v_.num_attention_heads = static_cast(value); + v_.num_attention_heads = static_cast(JSON::Get(value)); } else if (name == "num_key_value_heads") { - v_.num_key_value_heads = static_cast(value); + v_.num_key_value_heads = static_cast(JSON::Get(value)); } else if (name == "num_hidden_layers") { - v_.num_hidden_layers = static_cast(value); + v_.num_hidden_layers = static_cast(JSON::Get(value)); } else if 
(name == "head_size") { - v_.head_size = static_cast(value); + v_.head_size = static_cast(JSON::Get(value)); } else throw JSON::unknown_value_error{}; } @@ -341,11 +321,11 @@ struct Decoder_Element : JSON::Element { struct VisionInputs_Element : JSON::Element { explicit VisionInputs_Element(Config::Model::Vision::Inputs& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "pixel_values") { - v_.pixel_values = value; + v_.pixel_values = JSON::Get(value); } else if (name == "image_sizes") { - v_.image_sizes = value; + v_.image_sizes = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -357,9 +337,9 @@ struct VisionInputs_Element : JSON::Element { struct VisionOutputs_Element : JSON::Element { explicit VisionOutputs_Element(Config::Model::Vision::Outputs& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "image_features") { - v_.image_features = value; + v_.image_features = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -371,9 +351,9 @@ struct VisionOutputs_Element : JSON::Element { struct Vision_Element : JSON::Element { explicit Vision_Element(Config::Model::Vision& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "filename") { - v_.filename = value; + v_.filename = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -396,8 +376,8 @@ struct Vision_Element : JSON::Element { struct Eos_Array_Element : JSON::Element { explicit Eos_Array_Element(Config::Model& v) : v_{v} {} - void OnNumber(std::string_view name, double value) override { - v_.eos_token_ids.push_back(static_cast(value)); + void OnValue(std::string_view name, JSON::Value value) override { + 
v_.eos_token_ids.push_back(static_cast(JSON::Get(value))); } void OnComplete(bool empty) override { @@ -419,11 +399,11 @@ struct Eos_Array_Element : JSON::Element { struct EmbeddingInputs_Element : JSON::Element { explicit EmbeddingInputs_Element(Config::Model::Embedding::Inputs& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "input_ids") { - v_.input_ids = value; + v_.input_ids = JSON::Get(value); } else if (name == "image_features") { - v_.image_features = value; + v_.image_features = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -435,9 +415,9 @@ struct EmbeddingInputs_Element : JSON::Element { struct EmbeddingOutputs_Element : JSON::Element { explicit EmbeddingOutputs_Element(Config::Model::Embedding::Outputs& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "inputs_embeds") { - v_.embeddings = value; + v_.embeddings = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -449,9 +429,9 @@ struct EmbeddingOutputs_Element : JSON::Element { struct Embedding_Element : JSON::Element { explicit Embedding_Element(Config::Model::Embedding& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "filename") { - v_.filename = value; + v_.filename = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -474,20 +454,20 @@ struct Embedding_Element : JSON::Element { struct PromptTemplates_Element : JSON::Element { explicit PromptTemplates_Element(std::optional& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { + void OnValue(std::string_view name, JSON::Value value) override { // if one of templates is given in json, then any non-specified template will 
be default "{Content}" if (name == "assistant") { EnsureAvailable(); - v_->assistant = value; + v_->assistant = JSON::Get(value); } else if (name == "prompt") { EnsureAvailable(); - v_->prompt = value; + v_->prompt = JSON::Get(value); } else if (name == "system") { EnsureAvailable(); - v_->system = value; + v_->system = JSON::Get(value); } else if (name == "user") { EnsureAvailable(); - v_->user = value; + v_->user = JSON::Get(value); } else { throw JSON::unknown_value_error{}; } @@ -506,28 +486,23 @@ struct PromptTemplates_Element : JSON::Element { struct Model_Element : JSON::Element { explicit Model_Element(Config::Model& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "type") { - v_.type = value; - } else - throw JSON::unknown_value_error{}; - } - - void OnNumber(std::string_view name, double value) override { - if (name == "vocab_size") { - v_.vocab_size = static_cast(value); + v_.type = JSON::Get(value); + } else if (name == "vocab_size") { + v_.vocab_size = static_cast(JSON::Get(value)); } else if (name == "context_length") { - v_.context_length = static_cast(value); + v_.context_length = static_cast(JSON::Get(value)); } else if (name == "pad_token_id") { - v_.pad_token_id = static_cast(value); + v_.pad_token_id = static_cast(JSON::Get(value)); } else if (name == "eos_token_id") { - v_.eos_token_id = static_cast(value); + v_.eos_token_id = static_cast(JSON::Get(value)); } else if (name == "bos_token_id") { - v_.bos_token_id = static_cast(value); + v_.bos_token_id = static_cast(JSON::Get(value)); } else if (name == "decoder_start_token_id") { - v_.decoder_start_token_id = static_cast(value); + v_.decoder_start_token_id = static_cast(JSON::Get(value)); } else if (name == "sep_token_id") { - v_.sep_token_id = static_cast(value); + v_.sep_token_id = static_cast(JSON::Get(value)); } else throw JSON::unknown_value_error{}; } @@ -570,50 +545,41 @@ 
struct Model_Element : JSON::Element { struct Search_Element : JSON::Element { explicit Search_Element(Config::Search& v) : v_{v} {} - void OnString(std::string_view name, std::string_view value) override { - throw JSON::unknown_value_error{}; - } - - void OnNumber(std::string_view name, double value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "min_length") { - v_.min_length = static_cast(value); + v_.min_length = static_cast(JSON::Get(value)); } else if (name == "max_length") { - v_.max_length = static_cast(value); + v_.max_length = static_cast(JSON::Get(value)); } else if (name == "batch_size") { - v_.batch_size = static_cast(value); + v_.batch_size = static_cast(JSON::Get(value)); } else if (name == "num_beams") { - v_.num_beams = static_cast(value); + v_.num_beams = static_cast(JSON::Get(value)); } else if (name == "num_return_sequences") { - v_.num_return_sequences = static_cast(value); + v_.num_return_sequences = static_cast(JSON::Get(value)); } else if (name == "top_k") { - v_.top_k = static_cast(value); + v_.top_k = static_cast(JSON::Get(value)); } else if (name == "top_p") { - v_.top_p = static_cast(value); + v_.top_p = static_cast(JSON::Get(value)); } else if (name == "temperature") { - v_.temperature = static_cast(value); + v_.temperature = static_cast(JSON::Get(value)); } else if (name == "repetition_penalty") { - v_.repetition_penalty = static_cast(value); + v_.repetition_penalty = static_cast(JSON::Get(value)); } else if (name == "length_penalty") { - v_.length_penalty = static_cast(value); + v_.length_penalty = static_cast(JSON::Get(value)); } else if (name == "no_repeat_ngram_size") { - v_.no_repeat_ngram_size = static_cast(value); + v_.no_repeat_ngram_size = static_cast(JSON::Get(value)); } else if (name == "diversity_penalty") { - v_.diversity_penalty = static_cast(value); + v_.diversity_penalty = static_cast(JSON::Get(value)); } else if (name == "length_penalty") { - v_.length_penalty = 
static_cast(value); + v_.length_penalty = static_cast(JSON::Get(value)); } else if (name == "random_seed") { - v_.random_seed = static_cast(value); - } else - throw JSON::unknown_value_error{}; - } - - void OnBool(std::string_view name, bool value) override { - if (name == "do_sample") { - v_.do_sample = value; + v_.random_seed = static_cast(JSON::Get(value)); + } else if (name == "do_sample") { + v_.do_sample = JSON::Get(value); } else if (name == "past_present_share_buffer") { - v_.past_present_share_buffer = value; + v_.past_present_share_buffer = JSON::Get(value); } else if (name == "early_stopping") { - v_.early_stopping = value; + v_.early_stopping = JSON::Get(value); } else throw JSON::unknown_value_error{}; } @@ -623,11 +589,11 @@ struct Search_Element : JSON::Element { }; void SetSearchNumber(Config::Search& search, std::string_view name, double value) { - Search_Element(search).OnNumber(name, value); + Search_Element(search).OnValue(name, value); } void SetSearchBool(Config::Search& search, std::string_view name, bool value) { - Search_Element(search).OnBool(name, value); + Search_Element(search).OnValue(name, value); } void ClearProviders(Config& config) { @@ -663,10 +629,7 @@ bool IsCudaGraphEnabled(Config::SessionOptions& session_options) { struct Root_Element : JSON::Element { explicit Root_Element(Config& config) : config_{config} {} - void OnString(std::string_view name, std::string_view value) override { - } - - void OnNumber(std::string_view name, double value) override { + void OnValue(std::string_view name, JSON::Value value) override { } Element& OnObject(std::string_view name) override { diff --git a/src/json.cpp b/src/json.cpp index 4d4d0aa91..bd2c2aefb 100644 --- a/src/json.cpp +++ b/src/json.cpp @@ -6,14 +6,8 @@ #include namespace JSON { - -Element& Element::OnArray(std::string_view /*name*/) { - throw unknown_value_error{}; -} - -Element& Element::OnObject(std::string_view /*name*/) { - throw unknown_value_error{}; -} +static constexpr 
const char* value_names[] = {"string", "number", "bool", "null"}; +static_assert(std::size(value_names) == std::variant_size_v); struct JSON { JSON(Element& element, std::string_view document); @@ -148,34 +142,41 @@ void JSON::Parse_Value(Element& element, std::string_view name) { Parse_Array(element_array); } break; case '"': { - element.OnString(name, Parse_String()); + element.OnValue(name, Parse_String()); } break; case 't': if (Skip("rue")) { - element.OnBool(name, true); + element.OnValue(name, true); } break; case 'f': if (Skip("alse")) { - element.OnBool(name, false); + element.OnValue(name, false); } break; case 'n': if (Skip("ull")) { - element.OnNull(name); + element.OnValue(name, nullptr); } break; default: if (c >= '0' && c <= '9' || c == '-') { --current_; - element.OnNumber(name, Parse_Number()); + element.OnValue(name, Parse_Number()); } else throw unknown_value_error{}; break; } } catch (const unknown_value_error&) { - throw std::runtime_error("Unknown value: " + std::string(name)); + throw std::runtime_error(" Unknown value \"" + std::string(name) + "\""); + } catch (const type_mismatch& e) { + throw std::runtime_error(std::string(name) + " - Expected a " + std::string(value_names[e.expected]) + " but saw a " + std::string(value_names[e.seen])); + } catch (const std::runtime_error& e) { + if (!name.empty()) + throw std::runtime_error(std::string(name) + ":" + e.what()); + throw; } + Parse_Whitespace(); } diff --git a/src/json.h b/src/json.h index 58bc16319..b489a2ad3 100644 --- a/src/json.h +++ b/src/json.h @@ -9,17 +9,29 @@ // namespace JSON { struct unknown_value_error : std::exception {}; // Throw this from any Element callback to throw a std::runtime error reporting the unknown value name +struct type_mismatch { // When a file has one type, but we're expecting another type. 
"seen" & "expected" are indices into the Value std::variant below + size_t seen, expected; +}; + +using Value = std::variant; + +// To see descriptive errors when types don't match, use this instead of std::get +template +T Get(Value& var) { + try { + return std::get(var); + } catch (const std::bad_variant_access&) { + throw type_mismatch{var.index(), Value{T{}}.index()}; + } +} struct Element { virtual void OnComplete(bool empty) {} // Called when parsing for this element is finished (empty is true when it's an empty element) - virtual void OnString(std::string_view name, std::string_view value) { throw unknown_value_error{}; } - virtual void OnNumber(std::string_view name, double value) { throw unknown_value_error{}; } - virtual void OnBool(std::string_view name, bool value) { throw unknown_value_error{}; } - virtual void OnNull(std::string_view name) { throw unknown_value_error{}; } + virtual void OnValue(std::string_view name, Value value) { throw unknown_value_error{}; } - virtual Element& OnArray(std::string_view name); - virtual Element& OnObject(std::string_view name); + virtual Element& OnArray(std::string_view name) { throw unknown_value_error{}; } + virtual Element& OnObject(std::string_view name) { throw unknown_value_error{}; } }; void Parse(Element& element, std::string_view document);