Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -412,21 +412,21 @@ String createPTX(KernelCallGraph kernelCallGraph, Object... args){
builder.ptxHeader(major, minor, target, addressSize);
out.append(builder.getTextAndReset());

if (CallGraph.usingModuleOp) {
System.out.println("Using ModuleOp for CudaBackend");
kernelCallGraph.moduleOp.functionTable().forEach((_, funcOp) -> {
CoreOp.FuncOp loweredFunc = OpTk.lower(funcOp);
loweredFunc = transformPTXPtrs(kernelCallGraph.computeContext.accelerator.lookup,loweredFunc, argsMap, usedMathFns);
invokedMethods.append(createFunction(new PTXHATKernelBuilder(addressSize).nl().nl(), loweredFunc, false));
});
} else {
if (CallGraph.noModuleOp) {
System.out.println("NOT using ModuleOp for CudaBackend");
for (KernelCallGraph.KernelReachableResolvedMethodCall k : kernelCallGraph.kernelReachableResolvedStream().toList()) {
CoreOp.FuncOp calledFunc = k.funcOp();
CoreOp.FuncOp loweredFunc = OpTk.lower(calledFunc);
loweredFunc = transformPTXPtrs(kernelCallGraph.computeContext.accelerator.lookup,loweredFunc, argsMap, usedMathFns);
invokedMethods.append(createFunction(new PTXHATKernelBuilder(addressSize).nl().nl(), loweredFunc, false));
}
} else {
System.out.println("Using ModuleOp for CudaBackend");
kernelCallGraph.moduleOp.functionTable().forEach((_, funcOp) -> {
CoreOp.FuncOp loweredFunc = OpTk.lower(funcOp);
loweredFunc = transformPTXPtrs(kernelCallGraph.computeContext.accelerator.lookup,loweredFunc, argsMap, usedMathFns);
invokedMethods.append(createFunction(new PTXHATKernelBuilder(addressSize).nl().nl(), loweredFunc, false));
});
}

lowered = transformPTXPtrs(kernelCallGraph.computeContext.accelerator.lookup,lowered, argsMap, usedMathFns);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,14 @@ public void dispatchKernel(KernelCallGraph kernelCallGraph, NDRange ndRange, Obj
// Here we receive a callgraph from the kernel entrypoint
// The first time we see this we need to convert the kernel entrypoint
// and rechable methods to a form that our mock backend can execute.
if (CallGraph.usingModuleOp) {
System.out.println("Using ModuleOp for MockBackend");
kernelCallGraph.moduleOp.functionTable().forEach((_, funcOp) -> {
});
} else {
if (CallGraph.noModuleOp) {
System.out.println("NOT using ModuleOp for MockBackend");
kernelCallGraph.kernelReachableResolvedStream().forEach(kr -> {

});
} else {
System.out.println("Using ModuleOp for MockBackend");
kernelCallGraph.moduleOp.functionTable().forEach((_, funcOp) -> {
});
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,21 +233,21 @@ public <T extends C99HATKernelBuilder<T>> String createCode(KernelCallGraph kern
kernelCallGraph.entrypoint.funcOp());

// Sorting by rank ensures we don't need forward declarations
if (CallGraph.usingModuleOp) {
System.out.println("Using ModuleOp for C99FFIBackend");
kernelCallGraph.moduleOp.functionTable()
.forEach((_, funcOp) -> builder
.nl()
.kernelMethod(buildContext,funcOp)
.nl());
} else {
if (CallGraph.noModuleOp) {
System.out.println("NOT using ModuleOp for C99FFIBackend");
kernelCallGraph.kernelReachableResolvedStream().sorted((lhs, rhs) -> rhs.rank - lhs.rank)
.forEach(kernelReachableResolvedMethod ->
builder
.nl()
.kernelMethod(buildContext,kernelReachableResolvedMethod.funcOp())
.nl());
} else {
System.out.println("Using ModuleOp for C99FFIBackend");
kernelCallGraph.moduleOp.functionTable()
.forEach((_, funcOp) -> builder
.nl()
.kernelMethod(buildContext,funcOp)
.nl());
}

builder.nl().kernelEntrypoint(buildContext, args).nl();
Expand Down
214 changes: 214 additions & 0 deletions hat/core/src/main/java/hat/BufferTagger.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
package hat;

import hat.buffer.Buffer;
import hat.ifacemapper.MappableIface;
import jdk.incubator.code.*;
import jdk.incubator.code.analysis.Inliner;
import jdk.incubator.code.analysis.SSA;
import jdk.incubator.code.dialect.core.CoreOp;
import jdk.incubator.code.dialect.java.*;

import java.lang.invoke.MethodHandles;
import java.lang.reflect.Method;
import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;

public class BufferTagger {
static HashMap<Value, AccessType> accessMap = new HashMap<>();
static HashMap<Value, Value> remappedVals = new HashMap<>(); // maps values to their "root" parameter/value
static HashMap<Block, List<Block.Parameter>> blockParams = new HashMap<>(); // holds block parameters for easy lookup

public enum AccessType {
NA(1),
RO(2),
WO(4),
RW(6),
NOT_BUFFER(0);

public final int value;
AccessType(int i) {
value = i;
}
}

// generates a list of AccessTypes matching the given FuncOp's parameter order
public static ArrayList<AccessType> getAccessList(MethodHandles.Lookup l, CoreOp.FuncOp f) {
CoreOp.FuncOp inlinedFunc = inlineLoop(l, f);
buildAccessMap(l, inlinedFunc);
ArrayList<AccessType> accessList = new ArrayList<>();
for (Block.Parameter p : inlinedFunc.body().entryBlock().parameters()) {
if (accessMap.containsKey(p)) {
accessList.add(accessMap.get(p)); // is an accessed buffer
} else if (getClass(l, p.type()) instanceof Class<?> c && MappableIface.class.isAssignableFrom(c)) {
accessList.add(AccessType.NA); // is a buffer but not accessed
} else {
accessList.add(AccessType.NOT_BUFFER); // is not a buffer
}
}
return accessList;
}

// inlines functions found in FuncOp f until no more inline-able functions are present
public static CoreOp.FuncOp inlineLoop(MethodHandles.Lookup l, CoreOp.FuncOp f) {
CoreOp.FuncOp ssaFunc = SSA.transform(f.transform(OpTransformer.LOWERING_TRANSFORMER));
AtomicBoolean changed = new AtomicBoolean(true);
while (changed.get()) { // loop until no more inline-able functions
changed.set(false);
ssaFunc = ssaFunc.transform((bb, op) -> {
if (op instanceof JavaOp.InvokeOp iop) {
MethodRef methodRef = iop.invokeDescriptor();
Method invokeOpCalledMethod;
try {
invokeOpCalledMethod = methodRef.resolveToMethod(l, iop.invokeKind());
} catch (ReflectiveOperationException _) {
throw new IllegalStateException("Could not resolve invokeOp to method");
}
if (invokeOpCalledMethod instanceof Method method) { // if method isn't a buffer access (is code reflected)
if (Op.ofMethod(method).isPresent()) {
CoreOp.FuncOp inline = Op.ofMethod(method).get(); // method to be inlined
CoreOp.FuncOp ssaInline = SSA.transform(inline.transform(OpTransformer.LOWERING_TRANSFORMER));

Block.Builder exit = Inliner.inline(bb, ssaInline, bb.context().getValues(iop.operands()), (_, v) -> {
if (v != null) bb.context().mapValue(iop.result(), v);
});

if (!exit.parameters().isEmpty()) {
bb.context().mapValue(iop.result(), exit.parameters().getFirst());
}
changed.set(true);
return exit.rebind(bb.context(), bb.transformer()); // return exit in same context as block
}
}
}
bb.op(op);
return bb;
});
}
return ssaFunc;
}

// creates the access map
public static void buildAccessMap(MethodHandles.Lookup l, CoreOp.FuncOp f) {
// build blockParams so that we can map params to "root" params later
for (Body b : f.bodies()) {
for (Block block : b.blocks()) {
if (!block.parameters().isEmpty()) {
blockParams.put(block, block.parameters());
}
}
}

f.traverse(null, (map, op) -> {
if (op instanceof CoreOp.BranchOp b) {
mapBranch(l, b.branch());
} else if (op instanceof CoreOp.ConditionalBranchOp cb) {
mapBranch(l, cb.trueBranch()); // handle true branch
mapBranch(l, cb.falseBranch()); // handle false branch
} else if (op instanceof JavaOp.InvokeOp iop) { // (almost) all the buffer accesses happen here
if (isAssignable(l, iop.invokeDescriptor().refType(), MappableIface.class)) {
updateAccessType(getRootValue(iop), getAccessType(iop)); // update buffer access
if (isAssignable(l, iop.invokeDescriptor().refType(), Buffer.class)
&& iop.result() != null && !(iop.resultType() instanceof PrimitiveType)
&& isAssignable(l, iop.resultType(), MappableIface.class)) {
// if we access a struct/union from a buffer, we map the struct/union to the buffer root
remappedVals.put(iop.result(), getRootValue(iop));
}
}
} else if (op instanceof CoreOp.VarOp vop) { // map the new VarOp to the "root" param
if (isAssignable(l, vop.resultType().valueType(), Buffer.class)) {
remappedVals.put(vop.initOperand(), getRootValue(vop));
}
} else if (op instanceof JavaOp.FieldAccessOp.FieldLoadOp flop) {
if (isAssignable(l, flop.fieldDescriptor().refType(), KernelContext.class)) {
updateAccessType(getRootValue(flop), AccessType.RO); // handle kc access
}
}
return map;
});
}

// maps the parameters of a block to the values passed to a branch
public static void mapBranch(MethodHandles.Lookup l, Block.Reference b) {
List<Value> args = b.arguments();
for (int i = 0; i < args.size(); i++) {
Value key = blockParams.get(b.targetBlock()).get(i);
Value val = args.get(i);

if (val instanceof Op.Result) {
// either find root param or it doesnt exist (is a constant for example)
if (isAssignable(l, val.type(), MappableIface.class)) {
val = getRootValue(((Op.Result) val).op());
if (val instanceof Block.Parameter) {
val = remappedVals.getOrDefault(val, val);
}
}
}
remappedVals.put(key, val);
}
}

// checks if a TypeElement is assignable to a certain class
public static boolean isAssignable(MethodHandles.Lookup l, TypeElement type, Class<?> clazz) {
Class<?> fopClass = getClass(l, type);
return (fopClass != null && (clazz.isAssignableFrom(fopClass)));
}

// retrieves the class of a TypeElement
public static Class<?> getClass(MethodHandles.Lookup l, TypeElement type) {
if (type instanceof ClassType classType) {
try {
return (Class<?>) classType.resolve(l);
} catch (ReflectiveOperationException e) {
throw new RuntimeException(e);
}
}
return null;
}

// retrieves "root" value of an op, the origin of the parameter (or value) used by the op
public static Value getRootValue(Op op) {
if (op.operands().isEmpty()) {
return op.result();
}
if (op.operands().getFirst() instanceof Block.Parameter param) {
return param;
}
Value val = op.operands().getFirst();
while (!(val instanceof Block.Parameter)) {
// or if the "root VarOp" is an invoke (not sure how to tell)
// if (tempOp instanceof JavaOp.InvokeOp iop
// && ((TypeElement) iop.resultType()) instanceof ClassType classType
// && !hasOperandType(iop, classType)) return ((CoreOp.VarOp) op);
val = ((Op.Result) val).op().operands().getFirst();
}
return val;
}

// retrieves accessType based on return value of InvokeOp
public static AccessType getAccessType(JavaOp.InvokeOp iop) {
return iop.invokeDescriptor().type().returnType().equals(JavaType.VOID) ? AccessType.WO : AccessType.RO;
}

// updates accessMap
public static void updateAccessType(Value val, AccessType curAccess) {
Value remappedVal = remappedVals.getOrDefault(val, val);
AccessType storedAccess = accessMap.get(remappedVal);
if (storedAccess == null) {
accessMap.put(remappedVal, curAccess);
} else if (curAccess != storedAccess && storedAccess != AccessType.RW) {
accessMap.put(remappedVal, AccessType.RW);
}
}

public static void printAccessMap() {
System.out.println("access map output:");
for (Value val : accessMap.keySet()) {
if (val instanceof Block.Parameter param) {
System.out.println("\t" + ((CoreOp.FuncOp) param.declaringBlock().parent().parent()).funcName()
+ " param w/ idx " + param.index() + ": " + accessMap.get(val));
} else {
System.out.println("\t" + val.toString() + ": " + accessMap.get(val));
}
}
}
}
6 changes: 6 additions & 0 deletions hat/core/src/main/java/hat/buffer/ArgArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
package hat.buffer;

import hat.Accelerator;
import hat.BufferTagger;
import hat.ComputeContext;
import hat.callgraph.KernelCallGraph;
import hat.ifacemapper.Schema;
Expand All @@ -34,6 +35,7 @@
import java.lang.foreign.ValueLayout;
import java.lang.invoke.MethodHandles;
import java.nio.ByteOrder;
import java.util.ArrayList;

import static hat.buffer.ArgArray.Arg.Value.Buf.UNKNOWN_BYTE;
import static java.lang.foreign.ValueLayout.JAVA_BYTE;
Expand Down Expand Up @@ -289,6 +291,8 @@ static ArgArray create(Accelerator accelerator, KernelCallGraph kernelCallGraph,

static void update(ArgArray argArray, KernelCallGraph kernelCallGraph, Object... args) {
Annotation[][] parameterAnnotations = kernelCallGraph.entrypoint.getMethod().getParameterAnnotations();
ArrayList<BufferTagger.AccessType> bufferAccessList = kernelCallGraph.bufferAccessList;

for (int i = 0; i < args.length; i++) {
Object argObject = args[i];
Arg arg = argArray.arg(i); // this should be invariant, but if we are called from create it will be 0 for all
Expand Down Expand Up @@ -324,6 +328,8 @@ static void update(ArgArray argArray, KernelCallGraph kernelCallGraph, Object...
buf.address(segment);
buf.bytes(segment.byteSize());
buf.access(accessByte);

assert bufferAccessList.get(i).value == accessByte;
}
default -> throw new IllegalStateException("Unexpected value: " + argObject);
}
Expand Down
2 changes: 1 addition & 1 deletion hat/core/src/main/java/hat/callgraph/CallGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public abstract class CallGraph<E extends Entrypoint> {
public final Set<MethodCall> calls = new HashSet<>();
public final Map<MethodRef, MethodCall> methodRefToMethodCallMap = new LinkedHashMap<>();
public CoreOp.ModuleOp moduleOp;
public static boolean usingModuleOp = Boolean.getBoolean("moduleOp");
public static boolean noModuleOp = Boolean.getBoolean("noModuleOp");
public Stream<MethodCall> callStream() {
return methodRefToMethodCallMap.values().stream();
}
Expand Down
6 changes: 3 additions & 3 deletions hat/core/src/main/java/hat/callgraph/ComputeCallGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,10 @@ public void updateDag(ComputeReachableResolvedMethodCall computeReachableResolve
}

public void close() {
if (CallGraph.usingModuleOp) {
closeWithModuleOp(entrypoint);
} else {
if (CallGraph.noModuleOp) {
updateDag(entrypoint);
} else {
closeWithModuleOp(entrypoint);
}
}

Expand Down
3 changes: 3 additions & 0 deletions hat/core/src/main/java/hat/callgraph/KernelCallGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
*/
package hat.callgraph;

import hat.BufferTagger;
import hat.buffer.Buffer;
import hat.optools.OpTk;
import jdk.incubator.code.Op;
Expand All @@ -38,6 +39,7 @@
public class KernelCallGraph extends CallGraph<KernelEntrypoint> {
public final ComputeCallGraph computeCallGraph;
public final Map<MethodRef, MethodCall> bufferAccessToMethodCallMap = new LinkedHashMap<>();
public final ArrayList<BufferTagger.AccessType> bufferAccessList;

public interface KernelReachable {
}
Expand Down Expand Up @@ -77,6 +79,7 @@ public Stream<KernelReachableResolvedMethodCall> kernelReachableResolvedStream()
super(computeCallGraph.computeContext, new KernelEntrypoint(null, methodRef, method, funcOp));
entrypoint.callGraph = this;
this.computeCallGraph = computeCallGraph;
bufferAccessList = BufferTagger.getAccessList(computeContext.accelerator.lookup, entrypoint.funcOp());
}

void updateDag(KernelReachableResolvedMethodCall kernelReachableResolvedMethodCall) {
Expand Down
Loading