Skip to content

Commit fffc4fc

Browse files
committed
Auto merge of #151395 - Zalathar:rollup-8gANGZS, r=Zalathar
Rollup of 8 pull requests Successful merges: - #149587 (coverage: Sort the expansion tree to help choose a single BCB for child expansions) - #150071 (Add dist step for Enzyme) - #150288 (Add scalar support for offload) - #151091 (Add new "hide deprecated items" setting in rustdoc) - #151255 (rustdoc: Fix ICE when deprecated note is not resolved on the correct `DefId`) - #151375 (Fix terminal width dependent tests) - #151384 (add basic `TokenStream` api tests) - #151391 (rustc-dev-guide subtree update) r? @ghost
2 parents 7981818 + 7c33b49 commit fffc4fc

63 files changed

Lines changed: 1125 additions & 255 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use rustc_codegen_ssa::back::write::{
1313
TargetMachineFactoryConfig, TargetMachineFactoryFn,
1414
};
1515
use rustc_codegen_ssa::base::wants_wasm_eh;
16+
use rustc_codegen_ssa::common::TypeKind;
1617
use rustc_codegen_ssa::traits::*;
1718
use rustc_codegen_ssa::{CompiledModule, ModuleCodegen, ModuleKind};
1819
use rustc_data_structures::profiling::SelfProfilerRef;
@@ -33,6 +34,8 @@ use crate::back::owned_target_machine::OwnedTargetMachine;
3334
use crate::back::profiling::{
3435
LlvmSelfProfiler, selfprofile_after_pass_callback, selfprofile_before_pass_callback,
3536
};
37+
use crate::builder::SBuilder;
38+
use crate::builder::gpu_offload::scalar_width;
3639
use crate::common::AsCCharPtr;
3740
use crate::errors::{
3841
CopyBitcode, FromLlvmDiag, FromLlvmOptimizationDiag, LlvmError, UnknownCompression,
@@ -669,7 +672,17 @@ pub(crate) unsafe fn llvm_optimize(
669672
// Create the new parameter list, with ptr as the first argument
670673
let mut new_param_types = Vec::with_capacity(old_param_count as usize + 1);
671674
new_param_types.push(cx.type_ptr());
672-
new_param_types.extend(old_param_types);
675+
676+
// This relies on undocumented LLVM knowledge that scalars must be passed as i64
677+
for &old_ty in &old_param_types {
678+
let new_ty = match cx.type_kind(old_ty) {
679+
TypeKind::Half | TypeKind::Float | TypeKind::Double | TypeKind::Integer => {
680+
cx.type_i64()
681+
}
682+
_ => old_ty,
683+
};
684+
new_param_types.push(new_ty);
685+
}
673686

674687
// Create the new function type
675688
let ret_ty = unsafe { llvm::LLVMGetReturnType(old_fn_ty) };
@@ -682,10 +695,33 @@ pub(crate) unsafe fn llvm_optimize(
682695
let a0 = llvm::get_param(new_fn, 0);
683696
llvm::set_value_name(a0, CString::new("dyn_ptr").unwrap().as_bytes());
684697

698+
let bb = SBuilder::append_block(cx, new_fn, "entry");
699+
let mut builder = SBuilder::build(cx, bb);
700+
701+
let mut old_args_rebuilt = Vec::with_capacity(old_param_types.len());
702+
703+
for (i, &old_ty) in old_param_types.iter().enumerate() {
704+
let new_arg = llvm::get_param(new_fn, (i + 1) as u32);
705+
706+
let rebuilt = match cx.type_kind(old_ty) {
707+
TypeKind::Half | TypeKind::Float | TypeKind::Double | TypeKind::Integer => {
708+
let num_bits = scalar_width(cx, old_ty);
709+
710+
let trunc = builder.trunc(new_arg, cx.type_ix(num_bits));
711+
builder.bitcast(trunc, old_ty)
712+
}
713+
_ => new_arg,
714+
};
715+
716+
old_args_rebuilt.push(rebuilt);
717+
}
718+
719+
builder.ret_void();
720+
685721
// Here we map the old arguments to the new arguments, with an offset of 1 to make sure
686722
// that we don't use the newly added `%dyn_ptr`.
687723
unsafe {
688-
llvm::LLVMRustOffloadMapper(old_fn, new_fn);
724+
llvm::LLVMRustOffloadMapper(old_fn, new_fn, old_args_rebuilt.as_ptr());
689725
}
690726

691727
llvm::set_linkage(new_fn, llvm::get_linkage(old_fn));

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,21 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
9797
GenericBuilder { llbuilder, cx: scx }
9898
}
9999

100+
pub(crate) fn append_block(
101+
cx: &'a GenericCx<'ll, CX>,
102+
llfn: &'ll Value,
103+
name: &str,
104+
) -> &'ll BasicBlock {
105+
unsafe {
106+
let name = SmallCStr::new(name);
107+
llvm::LLVMAppendBasicBlockInContext(cx.llcx(), llfn, name.as_ptr())
108+
}
109+
}
110+
111+
pub(crate) fn trunc(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {
112+
unsafe { llvm::LLVMBuildTrunc(self.llbuilder, val, dest_ty, UNNAMED) }
113+
}
114+
100115
pub(crate) fn bitcast(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {
101116
unsafe { llvm::LLVMBuildBitCast(self.llbuilder, val, dest_ty, UNNAMED) }
102117
}

compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::ffi::CString;
22

33
use llvm::Linkage::*;
44
use rustc_abi::Align;
5+
use rustc_codegen_ssa::common::TypeKind;
56
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
67
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
78
use rustc_middle::bug;
@@ -361,7 +362,6 @@ pub(crate) fn add_global<'ll>(
361362
pub(crate) fn gen_define_handling<'ll>(
362363
cx: &CodegenCx<'ll, '_>,
363364
metadata: &[OffloadMetadata],
364-
types: &[&'ll Type],
365365
symbol: String,
366366
offload_globals: &OffloadGlobals<'ll>,
367367
) -> OffloadKernelGlobals<'ll> {
@@ -371,25 +371,18 @@ pub(crate) fn gen_define_handling<'ll>(
371371

372372
let offload_entry_ty = offload_globals.offload_entry_ty;
373373

374-
// It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
375-
// reference) types.
376-
let ptr_meta = types.iter().zip(metadata).filter_map(|(&x, meta)| match cx.type_kind(x) {
377-
rustc_codegen_ssa::common::TypeKind::Pointer => Some(meta),
378-
_ => None,
379-
});
380-
381374
// FIXME(Sa4dUs): add `OMP_MAP_TARGET_PARAM = 0x20` only if necessary
382-
let (ptr_sizes, ptr_transfer): (Vec<_>, Vec<_>) =
383-
ptr_meta.map(|m| (m.payload_size, m.mode.bits() | 0x20)).unzip();
375+
let (sizes, transfer): (Vec<_>, Vec<_>) =
376+
metadata.iter().map(|m| (m.payload_size, m.mode.bits() | 0x20)).unzip();
384377

385-
let offload_sizes = add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{symbol}"), &ptr_sizes);
378+
let offload_sizes = add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{symbol}"), &sizes);
386379
// Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2),
387380
// or both to and from the gpu (=3). Other values shouldn't affect us for now.
388381
// A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten
389382
// will be 2. For now, everything is 3, until we have our frontend set up.
390383
// 1+2+32: 1 (MapTo), 2 (MapFrom), 32 (Add one extra input ptr per function, to be used later).
391384
let memtransfer_types =
392-
add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}"), &ptr_transfer);
385+
add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}"), &transfer);
393386

394387
// Next: For each function, generate these three entries. A weak constant,
395388
// the llvm.rodata entry name, and the llvm_offload_entries value
@@ -445,13 +438,25 @@ fn declare_offload_fn<'ll>(
445438
)
446439
}
447440

441+
pub(crate) fn scalar_width<'ll>(cx: &'ll SimpleCx<'_>, ty: &'ll Type) -> u64 {
442+
match cx.type_kind(ty) {
443+
TypeKind::Half
444+
| TypeKind::Float
445+
| TypeKind::Double
446+
| TypeKind::X86_FP80
447+
| TypeKind::FP128
448+
| TypeKind::PPC_FP128 => cx.float_width(ty) as u64,
449+
TypeKind::Integer => cx.int_width(ty),
450+
other => bug!("scalar_width was called on a non scalar type {other:?}"),
451+
}
452+
}
453+
448454
// For each kernel *call*, we now use some of our previous declared globals to move data to and from
449455
// the gpu. For now, we only handle the data transfer part of it.
450456
// If two consecutive kernels use the same memory, we still move it to the host and back to the gpu.
451457
// Since in our frontend users (by default) don't have to specify data transfer, this is something
452-
// we should optimize in the future! We also assume that everything should be copied back and forth,
453-
// but sometimes we can directly zero-allocate on the device and only move back, or if something is
454-
// immutable, we might only copy it to the device, but not back.
458+
// we should optimize in the future! In some cases we can directly zero-allocate on the device and
459+
// only move data back, or if something is immutable, we might only copy it to the device.
455460
//
456461
// Current steps:
457462
// 0. Alloca some variables for the following steps
@@ -538,8 +543,34 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
538543
let mut geps = vec![];
539544
let i32_0 = cx.get_const_i32(0);
540545
for &v in args {
541-
let gep = builder.inbounds_gep(cx.type_f32(), v, &[i32_0]);
542-
vals.push(v);
546+
let ty = cx.val_ty(v);
547+
let ty_kind = cx.type_kind(ty);
548+
let (base_val, gep_base) = match ty_kind {
549+
TypeKind::Pointer => (v, v),
550+
TypeKind::Half | TypeKind::Float | TypeKind::Double | TypeKind::Integer => {
551+
// FIXME(Sa4dUs): check for `f128` support, latest NVIDIA cards support it
552+
let num_bits = scalar_width(cx, ty);
553+
554+
let bb = builder.llbb();
555+
unsafe {
556+
llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, builder.llfn());
557+
}
558+
let addr = builder.direct_alloca(cx.type_i64(), Align::EIGHT, "addr");
559+
unsafe {
560+
llvm::LLVMPositionBuilderAtEnd(builder.llbuilder, bb);
561+
}
562+
563+
let cast = builder.bitcast(v, cx.type_ix(num_bits));
564+
let value = builder.zext(cast, cx.type_i64());
565+
builder.store(value, addr, Align::EIGHT);
566+
(value, addr)
567+
}
568+
other => bug!("offload does not support {other:?}"),
569+
};
570+
571+
let gep = builder.inbounds_gep(cx.type_f32(), gep_base, &[i32_0]);
572+
573+
vals.push(base_val);
543574
geps.push(gep);
544575
}
545576

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1394,7 +1394,8 @@ fn codegen_offload<'ll, 'tcx>(
13941394
let args = get_args_from_tuple(bx, args[3], fn_target);
13951395
let target_symbol = symbol_name_for_instance_in_crate(tcx, fn_target, LOCAL_CRATE);
13961396

1397-
let sig = tcx.fn_sig(fn_target.def_id()).skip_binder().skip_binder();
1397+
let sig = tcx.fn_sig(fn_target.def_id()).skip_binder();
1398+
let sig = tcx.instantiate_bound_regions_with_erased(sig);
13981399
let inputs = sig.inputs();
13991400

14001401
let metadata = inputs.iter().map(|ty| OffloadMetadata::from_ty(tcx, *ty)).collect::<Vec<_>>();
@@ -1409,7 +1410,7 @@ fn codegen_offload<'ll, 'tcx>(
14091410
return;
14101411
}
14111412
};
1412-
let offload_data = gen_define_handling(&cx, &metadata, &types, target_symbol, offload_globals);
1413+
let offload_data = gen_define_handling(&cx, &metadata, target_symbol, offload_globals);
14131414
gen_call_handling(bx, &offload_data, &args, &types, &metadata, offload_globals, &offload_dims);
14141415
}
14151416

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1675,7 +1675,11 @@ mod Offload {
16751675
_M: &'a Module,
16761676
_host_out: *const c_char,
16771677
) -> bool;
1678-
pub(crate) fn LLVMRustOffloadMapper<'a>(OldFn: &'a Value, NewFn: &'a Value);
1678+
pub(crate) fn LLVMRustOffloadMapper<'a>(
1679+
OldFn: &'a Value,
1680+
NewFn: &'a Value,
1681+
RebuiltArgs: *const &Value,
1682+
);
16791683
}
16801684
}
16811685

@@ -1702,7 +1706,11 @@ mod Offload_fallback {
17021706
unimplemented!("This rustc version was not built with LLVM Offload support!");
17031707
}
17041708
#[allow(unused_unsafe)]
1705-
pub(crate) unsafe fn LLVMRustOffloadMapper<'a>(_OldFn: &'a Value, _NewFn: &'a Value) {
1709+
pub(crate) unsafe fn LLVMRustOffloadMapper<'a>(
1710+
_OldFn: &'a Value,
1711+
_NewFn: &'a Value,
1712+
_RebuiltArgs: *const &Value,
1713+
) {
17061714
unimplemented!("This rustc version was not built with LLVM Offload support!");
17071715
}
17081716
}

compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,12 @@ extern "C" bool LLVMRustOffloadEmbedBufferInModule(LLVMModuleRef HostM,
223223
return true;
224224
}
225225

226-
extern "C" void LLVMRustOffloadMapper(LLVMValueRef OldFn, LLVMValueRef NewFn) {
226+
// Clone OldFn into NewFn, remapping its arguments to RebuiltArgs.
227+
// Each arg of OldFn is replaced with the corresponding value in RebuiltArgs.
228+
// For scalars, RebuiltArgs contains the value cast and/or truncated to the
229+
// original type.
230+
extern "C" void LLVMRustOffloadMapper(LLVMValueRef OldFn, LLVMValueRef NewFn,
231+
const LLVMValueRef *RebuiltArgs) {
227232
llvm::Function *oldFn = llvm::unwrap<llvm::Function>(OldFn);
228233
llvm::Function *newFn = llvm::unwrap<llvm::Function>(NewFn);
229234

@@ -232,15 +237,25 @@ extern "C" void LLVMRustOffloadMapper(LLVMValueRef OldFn, LLVMValueRef NewFn) {
232237
llvm::ValueToValueMapTy vmap;
233238
auto newArgIt = newFn->arg_begin();
234239
newArgIt->setName("dyn_ptr");
235-
++newArgIt; // skip %dyn_ptr
240+
241+
unsigned i = 0;
236242
for (auto &oldArg : oldFn->args()) {
237-
vmap[&oldArg] = &*newArgIt++;
243+
vmap[&oldArg] = unwrap<Value>(RebuiltArgs[i++]);
238244
}
239245

240246
llvm::SmallVector<llvm::ReturnInst *, 8> returns;
241247
llvm::CloneFunctionInto(newFn, oldFn, vmap,
242248
llvm::CloneFunctionChangeType::LocalChangesOnly,
243249
returns);
250+
251+
BasicBlock &entry = newFn->getEntryBlock();
252+
BasicBlock &clonedEntry = *std::next(newFn->begin());
253+
254+
if (entry.getTerminator())
255+
entry.getTerminator()->eraseFromParent();
256+
257+
IRBuilder<> B(&entry);
258+
B.CreateBr(&clonedEntry);
244259
}
245260
#endif
246261

compiler/rustc_middle/src/ty/offload_meta.rs

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -78,16 +78,13 @@ impl MappingFlags {
7878
use rustc_ast::Mutability::*;
7979

8080
match ty.kind() {
81-
ty::Bool
82-
| ty::Char
83-
| ty::Int(_)
84-
| ty::Uint(_)
85-
| ty::Float(_)
86-
| ty::Adt(_, _)
87-
| ty::Tuple(_)
88-
| ty::Array(_, _)
89-
| ty::Alias(_, _)
90-
| ty::Param(_) => MappingFlags::TO,
81+
ty::Bool | ty::Char | ty::Int(_) | ty::Uint(_) | ty::Float(_) => {
82+
MappingFlags::LITERAL | MappingFlags::IMPLICIT
83+
}
84+
85+
ty::Adt(_, _) | ty::Tuple(_) | ty::Array(_, _) | ty::Alias(_, _) | ty::Param(_) => {
86+
MappingFlags::TO
87+
}
9188

9289
ty::RawPtr(_, Not) | ty::Ref(_, _, Not) => MappingFlags::TO,
9390

0 commit comments

Comments
 (0)