Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 01697c6

Browse files
committed
Initial version of reverse mode autodiff
1 parent 57d9b0e commit 01697c6

File tree

11 files changed

+596
-87
lines changed

11 files changed

+596
-87
lines changed

include/tc/core/autodiff.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/**
2+
* Copyright (c) 2017-present, Facebook, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include "tc/lang/tree.h"
19+
20+
#include <ostream>
21+
22+
namespace tc {
23+
24+
std::string differentiate(const std::string& source);
25+
26+
} // namespace tc

include/tc/lang/sema.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@ static inline TreeRef match_types(TreeRef a, TreeRef b) {
155155
/// - replace TK_APPLY with TK_BUILT_IN for built in functions
156156
/// - checks that all variables are defined, and creates index/reduction
157157
/// variable objects.
158+
// - replaces augumented assignments that have no reduction variables
159+
// with regular assignents
158160
struct Sema {
159161
std::unordered_map<TreeRef, TreeRef> expr_to_type;
160162

include/tc/lang/tree_views.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,14 +125,19 @@ struct ListViewIterator {
125125
bool operator!=(const ListViewIterator& rhs) const {
126126
return it != rhs.it;
127127
}
128+
bool operator==(const ListViewIterator& rhs) const {
129+
return it == rhs.it;
130+
}
128131
T operator*() const {
129132
return T(*it);
130133
}
131-
void operator++() {
134+
ListViewIterator& operator++() {
132135
++it;
136+
return *this;
133137
}
134-
void operator--() {
138+
ListViewIterator& operator--() {
135139
--it;
140+
return *this;
136141
}
137142

138143
private:

src/core/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_library(
33

44
SHARED
55

6+
autodiff.cc
67
flags.cc
78
mapping_options.cc
89
mapping_options_cpp_printer.cc

0 commit comments

Comments
 (0)