Skip to content

Commit 147f755

Browse files
gmagogsfm authored and facebook-github-bot committed
Add SourceView which doesn't own source text as base class of Source (pytorch#65309)
Summary: This would save the cost of copying text from stack to heap in some cases (like parsing function schema during loading phase of libtorch.so) Pull Request resolved: pytorch#65309 Reviewed By: swolchok Differential Revision: D31060315 Pulled By: gmagogsfm fbshipit-source-id: 0caf7a688b40df52bb4388c5191d1a42351d6f1a
1 parent bff64e8 commit 147f755

21 files changed

+174
-90
lines changed

test/cpp/jit/test_class_import.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
namespace torch {
1010
namespace jit {
1111

12-
static const auto classSrcs1 = R"JIT(
12+
static constexpr c10::string_view classSrcs1 = R"JIT(
1313
class FooNestedTest:
1414
def __init__(self, y):
1515
self.y = y
@@ -26,7 +26,7 @@ class FooTest:
2626
self.x = self.class_attr.y + self.class_attr2.y
2727
)JIT";
2828

29-
static const auto classSrcs2 = R"JIT(
29+
static constexpr c10::string_view classSrcs2 = R"JIT(
3030
class FooTest:
3131
def __init__(self, x):
3232
self.dx = x
@@ -134,7 +134,7 @@ TEST(ClassImportTest, ClassDerive) {
134134
ASSERT_TRUE(newCls2->findMethod(method->name()));
135135
}
136136

137-
static const auto torchbindSrc = R"JIT(
137+
static constexpr c10::string_view torchbindSrc = R"JIT(
138138
class FooBar1234(Module):
139139
__parameters__ = []
140140
f : __torch__.torch.classes._TorchScriptTesting._StackString

test/cpp/jit/test_class_parser.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55

66
namespace torch {
77
namespace jit {
8-
const auto testSource = R"JIT(
8+
constexpr c10::string_view testSource = R"JIT(
99
class FooTest:
1010
def __init__(self, x):
1111
self.x = x

test/cpp/jit/test_interface.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,12 +18,12 @@ def one(self, x: Tensor, y: Tensor) -> Tensor:
1818
def forward(self, x: Tensor) -> Tensor:
1919
return x
2020
)JIT"};
21-
static const auto parentForward = R"JIT(
21+
static const std::string parentForward = R"JIT(
2222
def forward(self, x: Tensor) -> Tensor:
2323
return self.subMod.forward(x)
2424
)JIT";
2525

26-
static const auto moduleInterfaceSrc = R"JIT(
26+
static constexpr c10::string_view moduleInterfaceSrc = R"JIT(
2727
class OneForward(ModuleInterface):
2828
def one(self, x: Tensor, y: Tensor) -> Tensor:
2929
pass

test/cpp/jit/test_module_api.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
namespace torch {
1414
namespace jit {
1515

16-
static const auto moduleInterfaceSrc = R"JIT(
16+
static constexpr c10::string_view moduleInterfaceSrc = R"JIT(
1717
class OneInterface(ModuleInterface):
1818
def one(self, x: Tensor, y: Tensor) -> Tensor:
1919
pass
@@ -27,7 +27,7 @@ def forward(self, x: Tensor) -> Tensor:
2727
return self.attr + x
2828
)JIT"};
2929

30-
static const auto parentForward = R"JIT(
30+
static const std::string parentForward = R"JIT(
3131
def forward(self, x: Tensor) -> Tensor:
3232
return self.subMod1.one(x, x) + self.subMod2.one(x, x)
3333
)JIT";

torch/csrc/jit/frontend/error_report.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,7 @@ struct TORCH_API ErrorReport : public std::exception {
3939
friend const ErrorReport& operator<<(const ErrorReport& e, const T& t);
4040

4141
mutable std::stringstream ss;
42-
SourceRange context;
42+
OwnedSourceRange context;
4343
mutable std::string the_message;
4444
std::vector<Call> error_stack;
4545
};

torch/csrc/jit/frontend/function_schema_parser.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ namespace jit {
2727
namespace {
2828
struct SchemaParser {
2929
SchemaParser(const std::string& str)
30-
: L(std::make_shared<Source>(str)),
30+
: L(std::make_shared<SourceView>(c10::string_view(str))),
3131
type_parser(L, /*parse_complete_tensor_types*/ false) {}
3232

3333
either<OperatorName, FunctionSchema> parseDeclaration() {

torch/csrc/jit/frontend/lexer.h

Lines changed: 8 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -185,7 +185,7 @@ struct TORCH_API SharedParserData {
185185
// find the longest match of str.substring(pos) against a token, return true
186186
// if successful filling in kind, start,and len
187187
bool match(
188-
const std::string& str,
188+
c10::string_view str,
189189
size_t pos,
190190
bool continuation, // are we inside a scope where newlines don't count
191191
// (e.g. inside parens)
@@ -300,15 +300,15 @@ struct TORCH_API SharedParserData {
300300
// 1. skip whitespace
301301
// 2. handle comment or newline
302302
//
303-
bool isNumber(const std::string& str, size_t start, size_t* len) {
303+
bool isNumber(c10::string_view str, size_t start, size_t* len) {
304304
char first = str[start];
305305
// strtod allows numbers to start with + or - or nan or inf
306306
// http://en.cppreference.com/w/cpp/string/byte/strtof
307307
// but we want only the number part, otherwise 1+3 will turn into two
308308
// adjacent numbers in the lexer
309309
if (first == '-' || first == '+' || isalpha(first))
310310
return false;
311-
const char* startptr = str.c_str() + start;
311+
const char* startptr = str.data() + start;
312312
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
313313
char* endptr;
314314
torch::jit::strtod_c(startptr, &endptr);
@@ -321,7 +321,7 @@ struct TORCH_API SharedParserData {
321321
return *len > 0;
322322
}
323323

324-
bool isCharCount(char c, const std::string& str, size_t start, int len) {
324+
bool isCharCount(char c, c10::string_view str, size_t start, int len) {
325325
// count checks from [start, start + len)
326326
return start + len <= str.size() &&
327327
std::count(str.begin() + start, str.begin() + start + len, c) == len;
@@ -331,7 +331,7 @@ struct TORCH_API SharedParserData {
331331
// strings can be enclosed with 1 or 3 single or double quotes
332332
// if enclosed with 3 quotes newlines are valid
333333
// as elsewhere, backslash and new line should be ignored
334-
bool isString(const std::string& str, size_t start, size_t* len) {
334+
bool isString(c10::string_view str, size_t start, size_t* len) {
335335
char quote = str[start];
336336
if (quote != '\"' && quote != '\'')
337337
return false;
@@ -362,9 +362,8 @@ struct TORCH_API SharedParserData {
362362
bool isblank(int n) {
363363
return isspace(n) && n != '\n';
364364
}
365-
366365
// Make an exception ignoring comments for type annotation comments
367-
bool isTypeComment(const std::string& str, size_t pos) {
366+
bool isTypeComment(c10::string_view str, size_t pos) {
368367
const std::string type_string = "# type:";
369368
if (str.size() < pos + type_string.length()) {
370369
return false;
@@ -391,7 +390,7 @@ struct Token {
391390
};
392391

393392
struct Lexer {
394-
explicit Lexer(std::shared_ptr<Source> source)
393+
explicit Lexer(std::shared_ptr<SourceView> source)
395394
: source(std::move(source)),
396395
pos(0),
397396
nesting(0),
@@ -532,7 +531,7 @@ struct Lexer {
532531
return t;
533532
}
534533

535-
std::shared_ptr<Source> source;
534+
std::shared_ptr<SourceView> source;
536535
size_t pos;
537536
size_t nesting; // depth of ( [ { nesting...
538537
std::vector<int> indent_stack; // stack of indentation level of blocks

torch/csrc/jit/frontend/parser.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ Decl mergeTypesFromTypeComment(
4646
}
4747

4848
struct ParserImpl {
49-
explicit ParserImpl(const std::shared_ptr<Source>& source)
49+
explicit ParserImpl(const std::shared_ptr<SourceView>& source)
5050
: L(source), shared(sharedParserData()) {}
5151

5252
Ident parseIdent() {
@@ -801,7 +801,7 @@ struct ParserImpl {
801801
SharedParserData& shared;
802802
};
803803

804-
Parser::Parser(const std::shared_ptr<Source>& src)
804+
Parser::Parser(const std::shared_ptr<SourceView>& src)
805805
: pImpl(new ParserImpl(src)) {}
806806

807807
Parser::~Parser() = default;

torch/csrc/jit/frontend/parser.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@ TORCH_API Decl mergeTypesFromTypeComment(
1717
bool is_method);
1818

1919
struct TORCH_API Parser {
20-
explicit Parser(const std::shared_ptr<Source>& src);
20+
explicit Parser(const std::shared_ptr<SourceView>& src);
2121
TreeRef parseFunction(bool is_method);
2222
TreeRef parseClass();
2323
Decl parseTypeComment();

torch/csrc/jit/frontend/source_range.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@ size_t SourceRangeHasher::operator()(const torch::jit::SourceRange& key) const {
1010
std::hash<size_t>()(key.start()) ^ std::hash<size_t>()(key.end()));
1111
}
1212

13-
c10::optional<SourceRange> Source::findSourceRangeThatGenerated(
13+
c10::optional<SourceRange> SourceView::findSourceRangeThatGenerated(
1414
const SourceRange& range) {
1515
if (!gen_ranges_) {
1616
return c10::nullopt;
@@ -69,11 +69,11 @@ C10_EXPORT void SourceRange::print_with_context(
6969
bool highlight,
7070
const std::string& funcname) const {
7171
// This is an empty SourceRange, used as a sentinel value.
72-
if (!source_) {
72+
if (!source_view_) {
7373
return;
7474
}
7575

76-
const std::string& str = source_->text();
76+
c10::string_view str = source_view_->text();
7777
if (size() == str.size()) {
7878
// this is just the entire file, not a subset, so print it out.
7979
// primarily used to print out python stack traces

0 commit comments

Comments
 (0)