Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cherry-pick PR for 0.3.0-rc2 #528

Merged
merged 6 commits
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pipelines/stages/jobs/steps/nuget-win-step.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ steps:
DisplayName: 'ESRP - Sign C# dlls'
Pattern: '*OnnxRuntimeGenAI*.dll'
- powershell: |
$VERSION = '0.3.0-rc1'
$VERSION = '0.3.0-rc2'
nuget.exe pack Microsoft.ML.OnnxRuntimeGenAI.nuspec `
-Prop version=$VERSION `
-Prop genai_nuget_ext=$(genai_nuget_ext) `
Expand Down
2 changes: 1 addition & 1 deletion VERSION_INFO
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.3.0-rc1
0.3.0-rc2
2 changes: 2 additions & 0 deletions examples/python/phi3v.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def run(args: argparse.Namespace):
for _ in range(3):
print()

# Delete the generator to free the captured graph before creating another one
del generator

if __name__ == "__main__":
parser = argparse.ArgumentParser()
Expand Down
2 changes: 1 addition & 1 deletion src/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ struct RootObject_Element : JSON::Element {
};

void ParseConfig(const fs::path& filename, Config& config) {
std::ifstream file(filename, std::ios::binary | std::ios::ate);
std::ifstream file = filename.open(std::ios::binary | std::ios::ate);
if (!file.is_open()) {
throw std::runtime_error("Error opening " + filename.string());
}
Expand Down
153 changes: 147 additions & 6 deletions src/filesystem.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,152 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// TODO(baijumeswani): Remove experimental when packaging pipeline can use GCC > 8
#ifdef USE_EXPERIMENTAL_FILESYSTEM
#include <experimental/filesystem>
namespace fs = std::experimental::filesystem;
#pragma once

#ifdef _WIN32
#ifndef WIN32_LEAN_AND_MEAN
#define WIN32_LEAN_AND_MEAN // Exclude rarely-used stuff from Windows headers
#endif

#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <Windows.h>
#define ENABLE_INTSAFE_SIGNED_FUNCTIONS // Only unsigned intsafe math/casts available without this def
#include <intsafe.h>
#include <tchar.h>
#endif // _WIN32

#include <sys/stat.h>

#include <string>
#include <fstream>

namespace fs {

// Minimal replacement for std::filesystem::path, used while the packaging
// pipeline cannot rely on a compiler with full <filesystem> support
// (see the USE_EXPERIMENTAL_FILESYSTEM history of this header).
// The path is stored as a UTF-8 std::string; on Windows a UTF-16 mirror is
// kept so Win32 APIs (which need wide strings for non-ASCII paths) work.
class path {
 public:
  path() = default;

  // Intentionally implicit: callers construct paths directly from strings.
  path(const std::string& path) : path_(path) {
#ifdef _WIN32
    wpath_ = to_wstring();
#endif
  }

  // Platform-preferred separator used by join()/operator/.
  static constexpr char separator =
#ifdef _WIN32
      '\\';
#else
      '/';
#endif

  using ios_base = std::ios_base;

  // Opens the file for reading. On Windows the UTF-16 form of the path is
  // used so non-ASCII paths resolve correctly.
  std::ifstream open(ios_base::openmode mode = ios_base::in) const {
#ifdef _WIN32
    return std::ifstream(wpath_, mode);
#else
    return std::ifstream(path_, mode);
#endif  // _WIN32
  }

  // Opens the file for writing (same Windows UTF-16 rationale as open()).
  std::ofstream open_for_write(ios_base::openmode mode = ios_base::out) const {
#ifdef _WIN32
    return std::ofstream(wpath_, mode);
#else
    return std::ofstream(path_, mode);
#endif  // _WIN32
  }

  // UTF-8 representation of the path.
  const std::string& string() const {
    return path_;
  }

  // Returns a new path: "<this><separator><path>". No normalization is done.
  path join(const std::string& path) const {
    return path_ + separator + path;
  }

  path operator/(const std::string& path) const {
    return join(path);
  }

  path operator/(const path& path) const {
    return join(path.path_);
  }

  // Native-encoding C string: wide on Windows, narrow elsewhere. The pointer
  // is only valid for the lifetime of this object.
#ifdef _WIN32
  const wchar_t* c_str() const {
    return wpath_.c_str();
  }
#else
  const char* c_str() const {
    return path_.c_str();
  }
#endif

  // True iff the path exists and is a directory.
  bool is_directory() const {
#ifdef _WIN32
    // GetFileAttributesW returns INVALID_FILE_ATTRIBUTES (all bits set) on
    // failure, so the attribute value must be validated BEFORE testing the
    // directory bit; otherwise a nonexistent path reports as a directory.
    const DWORD attributes = GetFileAttributesW(wpath_.c_str());
    return attributes != INVALID_FILE_ATTRIBUTES && (attributes & FILE_ATTRIBUTE_DIRECTORY) != 0;
#else
    struct stat info;
    if (stat(path_.c_str(), &info) != 0) {
      return false;
    }
    return (info.st_mode & S_IFDIR) != 0;
#endif  // _WIN32
  }

  // True iff something exists at this path (file or directory on Windows;
  // on POSIX this uses an ifstream probe, so it may miss unreadable or
  // non-regular entries — NOTE(review): confirm that is acceptable here).
  bool exists() const {
#ifdef _WIN32
    return GetFileAttributesW(wpath_.c_str()) != INVALID_FILE_ATTRIBUTES;
#else
    return std::ifstream(path_).good();
#endif
  }

 private:
  std::string path_;  // UTF-8 path text

#ifdef _WIN32
  std::wstring wpath_;  // UTF-16 mirror of path_, for Win32 API calls

  // Converts path_ (UTF-8) to UTF-16 via MultiByteToWideChar.
  std::wstring to_wstring() const {
    // If there's nothing to convert, bail early.
    if (path_.empty()) {
      return {};
    }

    int codePage = CP_UTF8;
    int iSource;  // convert to int because Mb2Wc requires it.
    SizeTToInt(path_.size(), &iSource);

    // Ask how much space we will need.
    // In certain codepages, Mb2Wc will "successfully" produce zero characters
    // (like in CP50220, where a SHIFT-IN character is consumed but not
    // transformed into anything) without explicitly failing. When it does
    // this, GetLastError will return the last error encountered by the last
    // function that actually did have an error. So clear the last error first
    // to distinguish a genuine failure from a zero-length success.
    SetLastError(0);
    const auto iTarget = MultiByteToWideChar(codePage, 0, path_.data(), iSource, nullptr, 0);

    size_t cchNeeded;
    IntToSizeT(iTarget, &cchNeeded);

    // Allocate ourselves some space
    std::wstring out;
    out.resize(cchNeeded);

    // Attempt conversion for real.
    MultiByteToWideChar(codePage, 0, path_.data(), iSource, out.data(), iTarget);

    // Return as a string
    return out;
  }
#endif  // _WIN32
};

}  // namespace fs
2 changes: 1 addition & 1 deletion src/generators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ Generator::Generator(const Model& model, const GeneratorParams& params) : model_
if (params.search.max_length == 0)
throw std::runtime_error("search max_length is 0");
if (params.search.max_length > model.config_->model.context_length)
throw std::runtime_error("max_length (" + std::to_string(params.search.max_length) + ") cannot be greater than model context_length (" + std::to_string(params.search.max_length) + ")");
throw std::runtime_error("max_length (" + std::to_string(params.search.max_length) + ") cannot be greater than model context_length (" + std::to_string(model.config_->model.context_length) + ")");
if (params.batch_size < 1)
throw std::runtime_error("batch_size must be 1 or greater, is " + std::to_string(params.batch_size));
if (params.vocab_size < 1)
Expand Down
2 changes: 1 addition & 1 deletion src/logging.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void SetLogString(std::string_view name, std::string_view value) {
gp_logfile.reset();
else {
fs::path filename{std::string(value)};
gp_logfile = std::make_unique<std::ofstream>(filename);
gp_logfile = std::make_unique<std::ofstream>(filename.open_for_write());
}

if (gp_logfile)
Expand Down
5 changes: 5 additions & 0 deletions src/models/captured_graph_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ CapturedGraphInfoPtr CapturedGraphPool::ReserveCapturedGraph(const Model& model,
new_captured_graph->sb_extra_inputs_[extra_input.name] = std::make_unique<StaticBuffer>(allocator_device_, first_dim);
}

// Create the input embeddings if needed
if (!model.config_->model.embedding.filename.empty()) {
new_captured_graph->sb_embeddings_ = std::make_unique<StaticBuffer>(allocator_device_, max_beam_batch_size);
}

new_captured_graph->key_ = std::move(key);

return new_captured_graph;
Expand Down
3 changes: 2 additions & 1 deletion src/models/captured_graph_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ struct CapturedGraphInfo {
std::unique_ptr<Generators::StaticBuffer> sb_position_ids_;
std::unique_ptr<Generators::StaticBuffer> sb_attention_mask_;
std::unordered_map<std::string, std::unique_ptr<Generators::StaticBuffer>> sb_extra_inputs_;
std::unique_ptr<Generators::StaticBuffer> sb_embeddings_;
std::unique_ptr<CapturedGraphKey> key_;

#if USE_DML
Expand All @@ -152,7 +153,7 @@ struct CapturedGraphInfo {
// Generates a unique annotation ID across different captured graph objects. This is necessary because different
// generators could be alive at the same time and run the same batch size but with different static buffers, so
// they need to have different annotation IDs.
int GenerateUniqueAnnotationID(int batch_size) {
int GenerateUniqueAnnotationID(int batch_size) const {
// Keep the upper half (minus 1 for the sign bit) of the bits for the unique ID, and keep the lower half for the batch
// size. This should give us 32,767 values for the index and 65,535 values for the batch size, which is more than enough.
int bit_shift = sizeof(int) * 8 / 2;
Expand Down
23 changes: 5 additions & 18 deletions src/models/decoder_only.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
namespace Generators {
DecoderOnly_Model::DecoderOnly_Model(std::unique_ptr<Config> config, OrtEnv& ort_env)
    : Model{std::move(config)} {
  // The diff scrape left both the pre- and post-change OrtSession::Create
  // lines here; only the new one (with the explicit fs::path wrapper, so the
  // join uses fs::path::operator/ rather than string concatenation) is kept —
  // the duplicate would have created the decoder session twice.
  session_decoder_ = OrtSession::Create(ort_env, (config_->config_path / fs::path(config_->model.decoder.filename)).c_str(), session_options_.get());

  InitDeviceAllocator(*session_decoder_);
}
Expand All @@ -14,7 +14,7 @@ std::unique_ptr<State> DecoderOnly_Model::CreateState(RoamingArray<int32_t> sequ
}

DecoderOnly_State::DecoderOnly_State(const DecoderOnly_Model& model, RoamingArray<int32_t> sequence_lengths_unk, const GeneratorParams& params)
: State{params},
: State{params, model},
model_{model},
captured_graph_info_(model.GetCapturedGraphPool()->ReserveCapturedGraph(model, params)),
position_inputs_{model, *this, sequence_lengths_unk} {
Expand All @@ -26,26 +26,13 @@ DecoderOnly_State::DecoderOnly_State(const DecoderOnly_Model& model, RoamingArra
}

// Runs one decoding step and returns the logits for the next token.
// The scraped diff interleaved the removed pre-change body (first-run
// gpu_graph_id bookkeeping, the deleted current_batch_size_ member, and a
// stale two-argument State::Run call) with the new code; this is the
// reconstructed post-merge version, in which per-batch-size graph annotation
// is presumably handled inside State::Run — TODO(review): confirm, and
// confirm first_run_ now lives in the State base class.
RoamingArray<float> DecoderOnly_State::Run(int current_length, RoamingArray<int32_t> next_tokens, RoamingArray<int32_t> next_indices) {
  // On the first run the inputs were populated by the constructor; on
  // subsequent runs they must be advanced to the new tokens/indices.
  if (!first_run_) {
    UpdateInputs(next_tokens, next_indices, current_length);
  }

  int batch_size = static_cast<int>(input_ids_.GetShape()[0]);
  State::Run(*model_.session_decoder_, *model_.run_options_, batch_size);

  return logits_.Get();
}

Expand Down
2 changes: 0 additions & 2 deletions src/models/decoder_only.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ struct DecoderOnly_State : State {

const DecoderOnly_Model& model_;
CapturedGraphInfoPtr captured_graph_info_;
bool first_run_{true};
int current_batch_size_{0};

InputIDs input_ids_{model_, *this};
Logits logits_{model_, *this};
Expand Down
23 changes: 18 additions & 5 deletions src/models/embeddings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,13 @@ Embeddings::Embeddings(const Model& model, State& state, Embeddings::Mode mode,
// They are never the user provided/requested model inputs/outputs
// So only create the transient output and reuse that ortvalue for subsequent
// steps in the pipeline.
if (mode == Embeddings::Mode::Output)
if (mode == Embeddings::Mode::Output) {
if (state_.GetCapturedGraphInfo()) {
sb_embeddings_ = state_.GetCapturedGraphInfo()->sb_embeddings_.get();
}

embeddings_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
}
}

Embeddings::Embeddings(Embeddings&& other, State& state) : model_{other.model_},
Expand Down Expand Up @@ -51,10 +56,18 @@ void Embeddings::Add() {
}

// Shrinks the sequence dimension of the embeddings tensor to 1 after the
// prompt step (subsequent decoding steps process one token at a time).
// The scraped diff interleaved the removed old body (which unconditionally
// reallocated every call) with the new guarded version; this is the
// reconstructed post-merge code.
void Embeddings::UpdateSequenceLength() {
  // Only reallocate when the shape actually changes.
  if (shape_[1] != 1) {
    shape_[1] = 1;

    if (mode_ == Embeddings::Mode::Output) {
      if (!sb_embeddings_) {
        embeddings_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
      } else {
        // CUDA-graph capture requires stable addresses, so place the tensor
        // on the pre-reserved static buffer when one is available.
        embeddings_ = sb_embeddings_->CreateTensorOnStaticBuffer(shape_, type_);
      }

      state_.outputs_[index_] = embeddings_.get();
    }
  }
}

Expand Down
3 changes: 3 additions & 0 deletions src/models/embeddings.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ struct Embeddings {

OrtValue* Get() { return embeddings_.get(); }

auto& GetShape() const { return shape_; }

private:
const Model& model_;
State& state_;
Expand All @@ -32,6 +34,7 @@ struct Embeddings {
const std::string name_;
std::unique_ptr<OrtValue> embeddings_;
size_t index_{};
StaticBuffer* sb_embeddings_{};
};

} // namespace Generators
8 changes: 5 additions & 3 deletions src/models/gpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace Generators {

Gpt_Model::Gpt_Model(std::unique_ptr<Config> config, OrtEnv& ort_env)
    : Model{std::move(config)} {
  // The diff scrape left both the pre- and post-change OrtSession::Create
  // lines here; only the new one (with the explicit fs::path wrapper, so the
  // join uses fs::path::operator/ rather than string concatenation) is kept —
  // the duplicate would have created the decoder session twice.
  session_decoder_ = OrtSession::Create(ort_env, (config_->config_path / fs::path(config_->model.decoder.filename)).c_str(), session_options_.get());
  InitDeviceAllocator(*session_decoder_);
}

Expand All @@ -14,7 +14,7 @@ std::unique_ptr<State> Gpt_Model::CreateState(RoamingArray<int32_t> sequence_len
}

Gpt_State::Gpt_State(const Gpt_Model& model, RoamingArray<int32_t> sequence_lengths_unk, const GeneratorParams& params)
: State{params},
: State{params, model},
model_{model},
position_inputs_{model, *this, sequence_lengths_unk} {
input_ids_.Add();
Expand All @@ -25,13 +25,15 @@ Gpt_State::Gpt_State(const Gpt_Model& model, RoamingArray<int32_t> sequence_leng
}

// Runs one decoding step and returns the logits for the next token.
// The scraped diff retained the stale two-argument State::Run call next to
// the new three-argument (batch_size) call; only the new call is kept —
// leaving both would run the session twice per step.
RoamingArray<float> Gpt_State::Run(int current_length, RoamingArray<int32_t> next_tokens, RoamingArray<int32_t> next_indices) {
  int batch_size = static_cast<int>(input_ids_.GetShape()[0]);

  // On the first run the inputs were populated by the constructor; on
  // subsequent runs they must be advanced to the new tokens/indices.
  if (first_run_) {
    first_run_ = false;
  } else {
    UpdateInputs(next_tokens, next_indices, current_length);
  }

  State::Run(*model_.session_decoder_, *model_.run_options_, batch_size);
  return logits_.Get();
}

Expand Down
Loading
Loading