Skip to content

adding run-make test to autodiff #142444

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,10 @@ struct LLVMRustSanitizerOptions {
#ifdef ENZYME
extern "C" void registerEnzymeAndPassPipeline(llvm::PassBuilder &PB,
/* augmentPassBuilder */ bool);

extern "C" {
extern llvm::cl::opt<std::string> EnzymeFunctionToAnalyze;
}
#endif

extern "C" LLVMRustResult LLVMRustOptimize(
Expand Down Expand Up @@ -1069,6 +1073,15 @@ extern "C" LLVMRustResult LLVMRustOptimize(
return LLVMRustResult::Failure;
}

// Check if PrintTAFn was used and add type analysis pass if needed
if (!EnzymeFunctionToAnalyze.empty()) {
if (auto Err = PB.parsePassPipeline(MPM, "print-type-analysis")) {
std::string ErrMsg = toString(std::move(Err));
LLVMRustSetLastError(ErrMsg.c_str());
return LLVMRustResult::Failure;
}
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this changes are for fixing PrintTAFn flag, which was not working, should i create new PR for this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need for it, just make it it's own commit and that the history is clean.
E.g. end up with a first commit fixing the Wrapper, and a second commit introducing the tests.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I will clean the history after I am done writing tests at the end

if (PrintAfterEnzyme) {
// Handle the Rust flag `-Zautodiff=PrintModAfter`.
std::string Banner = "Module after EnzymeNewPM";
Expand Down
13 changes: 13 additions & 0 deletions tests/run-make/autodiff/type-trees/type-analysis/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Autodiff Type-Trees Type Analysis Tests

This directory contains run-make tests for the autodiff type-trees type analysis functionality. These tests verify that the autodiff compiler correctly analyzes and tracks type information for different Rust types during automatic differentiation.

## What These Tests Do

Each test compiles a simple Rust function with the `#[autodiff_reverse]` attribute and verifies that the compiler:

1. **Correctly identifies type information** in the generated LLVM IR
2. **Tracks type annotations** for variables and operations
3. **Preserves type context** through the autodiff transformation process

The tests capture the stdout from the autodiff compiler (which contains type analysis information) and verify it matches expected patterns using FileCheck.
26 changes: 26 additions & 0 deletions tests/run-make/autodiff/type-trees/type-analysis/array/array.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer}:{}
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float, [-1,8]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw i8, ptr %{{[0-9]+}}, i64 4, !dbg !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fadd float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw i8, ptr %{{[0-9]+}}, i64 8, !dbg !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fadd float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float, [-1,8]:Float@float}:{}
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float, [-1,8]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw i8, ptr %{{[0-9]+}}, i64 4, !dbg !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fadd float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw i8, ptr %{{[0-9]+}}, i64 8, !dbg !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fadd float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
21 changes: 21 additions & 0 deletions tests/run-make/autodiff/type-trees/type-analysis/array/array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//@ needs-autodiff
#![feature(autodiff)]

use std::autodiff::autodiff_reverse;

#[autodiff_reverse(d_square, Duplicated, Active)]
#[no_mangle]
fn callee(x: &[f32; 3]) -> f32 {
x[0] * x[0] + x[1] * x[1] + x[2] * x[2]
}

fn main() {
let x = [1.0f32, 2.0, 3.0];
let mut df_dx = [0.0f32; 3];
let out = callee(&x);
let out_ = d_square(&x, &mut df_dx, 1.0);
assert_eq!(out, out_);
assert_eq!(2.0, df_dx[0]);
assert_eq!(4.0, df_dx[1]);
assert_eq!(6.0, df_dx[2]);
}
25 changes: 25 additions & 0 deletions tests/run-make/autodiff/type-trees/type-analysis/array/rmake.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use std::fs;

use run_make_support::{llvm_filecheck, rfs, rustc};

fn main() {
// Compile the Rust file with the required flags, capturing both stdout and stderr
let output = rustc()
.input("array.rs")
.arg("-Zautodiff=Enable,PrintTAFn=callee")
.arg("-Zautodiff=NoPostopt")
.arg("-Copt-level=3")
.arg("-Clto=fat")
.arg("-g")
.run();

let stdout = output.stdout_utf8();
let stderr = output.stderr_utf8();

// Write the outputs to files
fs::write("array.stdout", stdout).unwrap();
fs::write("array.stderr", stderr).unwrap();

// Run FileCheck on the stdout using the check file
llvm_filecheck().patterns("array.check").stdin_buf(rfs::read("array.stdout")).run();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer}:{}
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
// CHECK-DAG: br label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: br i1 %{{[0-9]+}}, label %{{[0-9]+}}, label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = phi float [ 0.000000e+00, %{{[0-9]+}} ], [ %{{[0-9]+}}, %{{[0-9]+}} ]: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = phi i1 [ true, %{{[0-9]+}} ], [ false, %{{[0-9]+}} ]: {[-1]:Integer}
// CHECK-DAG: %{{[0-9]+}} = phi i64 [ 0, %{{[0-9]+}} ], [ 1, %{{[0-9]+}} ]: {[-1]:Integer}
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw [2 x [2 x float]], ptr %{{[0-9]+}}, i64 %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
// CHECK-DAG: br label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: br i1 %{{[0-9]+}}, label %{{[0-9]+}}, label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], [ %{{[0-9]+}}, %{{[0-9]+}} ]: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = phi i1 [ true, %{{[0-9]+}} ], [ false, %{{[0-9]+}} ]: {[-1]:Integer}
// CHECK-DAG: %{{[0-9]+}} = phi i64 [ 0, %{{[0-9]+}} ], [ 1, %{{[0-9]+}} ]: {[-1]:Integer}
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw [2 x float], ptr %{{[0-9]+}}, i64 %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
// CHECK-DAG: br label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], [ %{{[0-9]+}}, %{{[0-9]+}} ]: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = phi i1 [ true, %{{[0-9]+}} ], [ false, %{{[0-9]+}} ]: {[-1]:Integer}
// CHECK-DAG: %{{[0-9]+}} = phi i64 [ 0, %{{[0-9]+}} ], [ 1, %{{[0-9]+}} ]: {[-1]:Integer}
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw float, ptr %{{[0-9]+}}, i64 %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fadd float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: br i1 %{{[0-9]+}}, label %{{[0-9]+}}, label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer, [-1,0]:Float@float}:{}
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
// CHECK-DAG: br label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: br i1 %{{[0-9]+}}, label %{{[0-9]+}}, label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = phi float [ 0.000000e+00, %{{[0-9]+}} ], [ %{{[0-9]+}}, %{{[0-9]+}} ]: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = phi i1 [ true, %{{[0-9]+}} ], [ false, %{{[0-9]+}} ]: {[-1]:Integer}
// CHECK-DAG: %{{[0-9]+}} = phi i64 [ 0, %{{[0-9]+}} ], [ 1, %{{[0-9]+}} ]: {[-1]:Integer}
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw [2 x [2 x float]], ptr %{{[0-9]+}}, i64 %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
// CHECK-DAG: br label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: br i1 %{{[0-9]+}}, label %{{[0-9]+}}, label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], [ %{{[0-9]+}}, %{{[0-9]+}} ]: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = phi i1 [ true, %{{[0-9]+}} ], [ false, %{{[0-9]+}} ]: {[-1]:Integer}
// CHECK-DAG: %{{[0-9]+}} = phi i64 [ 0, %{{[0-9]+}} ], [ 1, %{{[0-9]+}} ]: {[-1]:Integer}
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw [2 x float], ptr %{{[0-9]+}}, i64 %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
// CHECK-DAG: br label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], [ %{{[0-9]+}}, %{{[0-9]+}} ]: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = phi i1 [ true, %{{[0-9]+}} ], [ false, %{{[0-9]+}} ]: {[-1]:Integer}
// CHECK-DAG: %{{[0-9]+}} = phi i64 [ 0, %{{[0-9]+}} ], [ 1, %{{[0-9]+}} ]: {[-1]:Integer}
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw float, ptr %{{[0-9]+}}, i64 %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fadd float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: br i1 %{{[0-9]+}}, label %{{[0-9]+}}, label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
//@ needs-autodiff
#![feature(autodiff)]

use std::autodiff::autodiff_reverse;

#[autodiff_reverse(d_square, Duplicated, Active)]
#[no_mangle]
fn callee(x: &[[[f32; 2]; 2]; 2]) -> f32 {
let mut sum = 0.0;
for i in 0..2 {
for j in 0..2 {
for k in 0..2 {
sum += x[i][j][k] * x[i][j][k];
}
}
}
sum
}

fn main() {
let x = [[[1.0f32, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]];
let mut df_dx = [[[0.0f32; 2]; 2]; 2];
let out = callee(&x);
let out_ = d_square(&x, &mut df_dx, 1.0);
assert_eq!(out, out_);
for i in 0..2 {
for j in 0..2 {
for k in 0..2 {
assert_eq!(df_dx[i][j][k], 2.0 * x[i][j][k]);
}
}
}
}
25 changes: 25 additions & 0 deletions tests/run-make/autodiff/type-trees/type-analysis/array3d/rmake.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use std::fs;

use run_make_support::{llvm_filecheck, rfs, rustc};

fn main() {
// Compile the Rust file with the required flags, capturing both stdout and stderr
let output = rustc()
.input("array3d.rs")
.arg("-Zautodiff=Enable,PrintTAFn=callee")
.arg("-Zautodiff=NoPostopt")
.arg("-Copt-level=3")
.arg("-Clto=fat")
.arg("-g")
.run();

let stdout = output.stdout_utf8();
let stderr = output.stderr_utf8();

// Write the outputs to files
fs::write("array3d.stdout", stdout).unwrap();
fs::write("array3d.stderr", stderr).unwrap();

// Run FileCheck on the stdout using the check file
llvm_filecheck().patterns("array3d.check").stdin_buf(rfs::read("array3d.stdout")).run();
}
12 changes: 12 additions & 0 deletions tests/run-make/autodiff/type-trees/type-analysis/box/box.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer}:{}
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,0]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = load ptr, ptr %{{[0-9]+}}, align 8, !dbg !{{[0-9]+}}, !nonnull !{{[0-9]+}}, !align !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer, [-1,0]:Pointer, [-1,0,0]:Float@float}:{}
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,0]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = load ptr, ptr %{{[0-9]+}}, align 8, !dbg !{{[0-9]+}}, !nonnull !{{[0-9]+}}, !align !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
19 changes: 19 additions & 0 deletions tests/run-make/autodiff/type-trees/type-analysis/box/box.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//@ needs-autodiff
#![feature(autodiff)]

use std::autodiff::autodiff_reverse;

#[autodiff_reverse(d_square, Duplicated, Active)]
#[no_mangle]
fn callee(x: &Box<f32>) -> f32 {
**x * **x
}

fn main() {
let x = Box::new(7.0f32);
let mut df_dx = Box::new(0.0f32);
let out = callee(&x);
let out_ = d_square(&x, &mut df_dx, 1.0);
assert_eq!(out, out_);
assert_eq!(14.0, *df_dx);
}
25 changes: 25 additions & 0 deletions tests/run-make/autodiff/type-trees/type-analysis/box/rmake.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use std::fs;

use run_make_support::{llvm_filecheck, rfs, rustc};

fn main() {
// Compile the Rust file with the required flags, capturing both stdout and stderr
let output = rustc()
.input("box.rs")
.arg("-Zautodiff=Enable,PrintTAFn=callee")
.arg("-Zautodiff=NoPostopt")
.arg("-Copt-level=3")
.arg("-Clto=fat")
.arg("-g")
.run();

let stdout = output.stdout_utf8();
let stderr = output.stderr_utf8();

// Write the outputs to files
fs::write("box.stdout", stdout).unwrap();
fs::write("box.stderr", stderr).unwrap();

// Run FileCheck on the stdout using the check file
llvm_filecheck().patterns("box.check").stdin_buf(rfs::read("box.stdout")).run();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer}:{}
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer, [-1,0]:Float@float}:{}
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//@ needs-autodiff
#![feature(autodiff)]

use std::autodiff::autodiff_reverse;

#[autodiff_reverse(d_square, Duplicated, Active)]
#[no_mangle]
fn callee(x: *const f32) -> f32 {
unsafe { *x * *x }
}

fn main() {
let x: f32 = 7.0;
let out = callee(&x as *const f32);
let mut df_dx: f32 = 0.0;
let out_ = d_square(&x as *const f32, &mut df_dx, 1.0);
assert_eq!(out, out_);
assert_eq!(14.0, df_dx);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use std::fs;

use run_make_support::{llvm_filecheck, rfs, rustc};

fn main() {
// Compile the Rust file with the required flags, capturing both stdout and stderr
let output = rustc()
.input("const_pointer.rs")
.arg("-Zautodiff=Enable,PrintTAFn=callee")
.arg("-Zautodiff=NoPostopt")
.arg("-Copt-level=3")
.arg("-Clto=fat")
.arg("-g")
.run();

let stdout = output.stdout_utf8();
let stderr = output.stderr_utf8();

// Write the outputs to files
fs::write("const_pointer.stdout", stdout).unwrap();
fs::write("const_pointer.stderr", stderr).unwrap();

// Run FileCheck on the stdout using the check file
llvm_filecheck()
.patterns("const_pointer.check")
.stdin_buf(rfs::read("const_pointer.stdout"))
.run();
}
12 changes: 12 additions & 0 deletions tests/run-make/autodiff/type-trees/type-analysis/f32/f32.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer}:{}
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}

// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer, [-1,0]:Float@float}:{}
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}

// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
20 changes: 20 additions & 0 deletions tests/run-make/autodiff/type-trees/type-analysis/f32/f32.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//@ needs-autodiff
#![feature(autodiff)]

use std::autodiff::autodiff_reverse;

#[autodiff_reverse(d_square, Duplicated, Active)]
#[no_mangle]
fn callee(x: &f32) -> f32 {
*x * *x
}

fn main() {
let x: f32 = 7.0;
let mut df_dx: f32 = 0.0;
let out = callee(&x);
let out_ = d_square(&x, &mut df_dx, 1.0);

assert_eq!(out, out_);
assert_eq!(14.0, df_dx);
}
Loading
Loading