Security Update and Enhancement for run.py #264
base: main
@@ -1,7 +1,7 @@
# Copyright 2024 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
@@ -13,60 +13,86 @@
# limitations under the License.

import logging
import hashlib

from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit
from runners import InferenceRunner, ModelRunner, sample_from_model


CKPT_PATH = "./checkpoints/"
CKPT_HASH = "expected_checkpoint_hash"


def validate_checkpoint(path, expected_hash):
    calculated_hash = hashlib.sha256(open(path, 'rb').read()).hexdigest()
Review comment: This is still left to fix. Calling …
    if calculated_hash != expected_hash:
        raise ValueError("Invalid checkpoint file!")
Review comment: Could this error message be improved? It might also be nice to utilize logging.

Reply: The key changes make it clear in the logs when a validation failure happens and provide the expected and actual hashes for diagnostics. Other enhancements are possible, but they would make the code longer; is this necessary?

Review comment: Please utilize Flake8 as well as some standardized code formatter; I'm noticing many inconsistencies in the code you submit. There's no problem that I notice with your usage of …

Review comment: 100% convinced this user is just repeating garbage from an LLM.

Review comment: AI trying to fix AI.
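The logging suggestion above could be sketched roughly as follows. `validate_checkpoint`, `CKPT_PATH`, and `CKPT_HASH` come from the PR; the logger setup, message wording, and use of a `with` block are my own assumptions, not code from the PR:

```python
import hashlib
import logging

logger = logging.getLogger(__name__)


def validate_checkpoint(path: str, expected_hash: str) -> None:
    """Raise ValueError if the file at `path` does not match `expected_hash`."""
    with open(path, "rb") as f:
        calculated_hash = hashlib.sha256(f.read()).hexdigest()
    if calculated_hash != expected_hash:
        # Log both hashes so a validation failure is diagnosable from the logs.
        logger.error(
            "Checkpoint validation failed for %s: expected %s, got %s",
            path, expected_hash, calculated_hash,
        )
        raise ValueError(
            f"Checkpoint at {path!r} failed integrity check: "
            f"expected sha256 {expected_hash}, got {calculated_hash}"
        )
    logger.info("Checkpoint %s passed integrity check.", path)
```

This is the shape the reviewer seems to be asking for: the exception still fires, but the expected and actual hashes also land in the logs.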
def main():
    grok_1_model = LanguageModelConfig(
        vocab_size=128 * 1024,
        pad_token=0,
        eos_token=2,
        sequence_len=8192,
        embedding_init_scale=1.0,
        output_multiplier_scale=0.5773502691896257,
        embedding_multiplier_scale=78.38367176906169,
        model=TransformerConfig(
            emb_size=48 * 128,
            widening_factor=8,
            key_size=128,
            num_q_heads=48,
            num_kv_heads=8,
            num_layers=64,
            attn_output_multiplier=0.08838834764831845,
            shard_activations=True,
            # MoE.
            num_experts=8,
            num_selected_experts=2,
            # Activation sharding.
            data_axis="data",
            model_axis="model",
        ),
    )
    inference_runner = InferenceRunner(
        pad_sizes=(1024,),
        runner=ModelRunner(
            model=grok_1_model,
            bs_per_device=0.125,
            checkpoint_path=CKPT_PATH,
        ),
        name="local",
        load=CKPT_PATH,
        tokenizer_path="./tokenizer.model",
        local_mesh_config=(1, 8),
        between_hosts_config=(1, 1),
    )
    inference_runner.initialize()
    gen = inference_runner.run()
    # Validate checkpoint integrity
    validate_checkpoint(CKPT_PATH, CKPT_HASH)
grok_1_model = LanguageModelConfig(

Review comment: The only change here is a dedent from the PEP8 standard 4 space indent.
 vocab_size=128 * 1024,
 pad_token=0,
 eos_token=2,
 sequence_len=8192,
 embedding_init_scale=1.0,
 output_multiplier_scale=0.5773502691896257,
 embedding_multiplier_scale=78.38367176906169,
 model=TransformerConfig(
  emb_size=48 * 128,
  widening_factor=8,
  key_size=128,
  num_q_heads=48,
  num_kv_heads=8,
  num_layers=64,
  attn_output_multiplier=0.08838834764831845,
  shard_activations=True,
  # MoE.
  num_experts=8,
  num_selected_experts=2,
  # Activation sharding.
  data_axis="data",
  model_axis="model",
 ),
)
inference_runner = InferenceRunner(
 pad_sizes=(1024,),
 runner=ModelRunner(
  model=grok_1_model,
  bs_per_device=0.125,
  checkpoint_path=CKPT_PATH,
  # Limit inference rate
  inference_runner.rate_limit = 100

Review comment: This appears to reference …

 ),
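As the review notes, `inference_runner.rate_limit = 100` is a statement placed inside a call's argument list, and nothing in the grok-1 `InferenceRunner` suggests a `rate_limit` attribute exists. If rate limiting were actually wanted, one hedged, stdlib-only way to sketch it is a standalone token bucket; everything below (the class, its names, the parameters) is hypothetical and not part of the PR:

```python
import time


class TokenBucket:
    """Simple token-bucket rate limiter: allows roughly `rate` ops per second."""

    def __init__(self, rate: float, capacity: float):
        self.rate = rate          # tokens added per second
        self.capacity = capacity  # maximum burst size
        self.tokens = capacity
        self.last = time.monotonic()

    def allow(self) -> bool:
        """Return True if one operation may proceed now, consuming a token."""
        now = time.monotonic()
        # Refill proportionally to elapsed time, capped at capacity.
        self.tokens = min(self.capacity, self.tokens + (now - self.last) * self.rate)
        self.last = now
        if self.tokens >= 1:
            self.tokens -= 1
            return True
        return False
```

A caller would then guard each sampling request with something like `if limiter.allow(): ...` rather than assigning an attribute on the runner mid-construction.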
 name="local",
 load=CKPT_PATH,
 tokenizer_path="./tokenizer.model",

Review comment: If you were to improve anything, I'd suggest improving how file paths are defined by utilizing …

 local_mesh_config=(1, 8),
 between_hosts_config=(1, 1),
)
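The (truncated) review comment above concerns how file paths like `CKPT_PATH` and `tokenizer_path` are defined. One common approach is to build them with `pathlib` instead of bare string literals; the helper below is my illustration of that idea, not code from the PR, and the function name is hypothetical:

```python
from pathlib import Path


def resolve_paths(base_dir: str) -> dict:
    """Build absolute checkpoint/tokenizer paths under base_dir using pathlib."""
    base = Path(base_dir).resolve()
    return {
        # Mirrors the PR's "./checkpoints/" and "./tokenizer.model" literals.
        "checkpoint": base / "checkpoints",
        "tokenizer": base / "tokenizer.model",
    }
```

Resolving against an explicit base directory avoids depending on the current working directory, and `Path` objects can be passed anywhere the runner expects a path-like value.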

inference_runner.initialize()

gen = inference_runner.run()

inp = "The answer to life the universe and everything is of course"
print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01))

# Add authentication
@app.route("/inference")
@auth.login_required
def inference():
 ...

gen = inference_runner.run()

# Rest of inference code
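The `@app.route`/`@auth.login_required` lines above presuppose a Flask `app` and an `auth` object that are defined nowhere in `run.py`. To show the underlying idea without those missing dependencies, here is a stdlib-only sketch of a login-required decorator; every name in it (`API_TOKEN`, `login_required`, the `token` parameter) is hypothetical and not part of the PR:

```python
import functools
import hmac

API_TOKEN = "change-me"  # hypothetical shared secret; load from config in practice


def login_required(func):
    """Reject calls whose `token` keyword does not match the expected secret."""
    @functools.wraps(func)
    def wrapper(*args, token: str = "", **kwargs):
        # hmac.compare_digest avoids leaking information via timing differences.
        if not hmac.compare_digest(token, API_TOKEN):
            raise PermissionError("Authentication failed")
        return func(*args, **kwargs)
    return wrapper


@login_required
def inference(prompt: str) -> str:
    # Stand-in for the real sampling call.
    return f"generated text for: {prompt}"
```

In an actual web deployment the same check would live in a framework-level hook (e.g. a Flask extension) rather than a hand-rolled decorator, but the control flow is the same.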
if __name__ == "__main__":
 logging.basicConfig(level=logging.INFO)
 main()

Review comment: 2 space indent is not standard. Please view PEP 8.

Reply: I'm using 1 space, and you should comment that to the original repo; you're making our lives complicated enough with reviews that don't make any sense at all!

Review comment: This change is not accepted and requires fixing. A 1-space indent is not standard and worsens readability and consistency in all affected code. Not accepted.

Review comment: Overall, a complete waste of a PR. Nothing of value was added.
Review comment: Please use a context manager (with) for opening and reading the given path. It might also be in our best interest to utilize type hints in the function signature.

Reply: The key changes:
- Added type hints for the path (Text) and expected_hash (Text) parameters.
- Opened the file using a with statement, which automatically closes it when done.
- Stored the file contents in a variable called 'contents' to avoid re-reading the file.
- Passed the contents variable to hashlib.sha256 rather than the file object.
It seems we would need to import Text from typing; is this necessary?

Review comment: I think the Text import is superfluous and could just as easily be replaced with str without importing any extra type hints.
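Pulling the thread's suggestions together (a `with` context manager, plain `str` instead of `typing.Text`), the validation function could look like the sketch below. This is my reconstruction of what the reviewers are asking for, not code that appears in the PR:

```python
import hashlib


def validate_checkpoint(path: str, expected_hash: str) -> None:
    """Verify the SHA-256 digest of the file at `path` against `expected_hash`."""
    # The with statement closes the file even if reading raises.
    with open(path, "rb") as f:
        contents = f.read()
    calculated_hash = hashlib.sha256(contents).hexdigest()
    if calculated_hash != expected_hash:
        raise ValueError(
            f"Invalid checkpoint file {path!r}: "
            f"expected {expected_hash}, got {calculated_hash}"
        )
```

Since Python 3, `str` carries the same meaning `typing.Text` was introduced for, so no extra import is needed, exactly as the last comment says.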