diff --git a/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaBackend.java b/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaBackend.java index 88db7a801db..99c68677770 100644 --- a/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaBackend.java +++ b/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaBackend.java @@ -412,14 +412,7 @@ String createPTX(KernelCallGraph kernelCallGraph, Object... args){ builder.ptxHeader(major, minor, target, addressSize); out.append(builder.getTextAndReset()); - if (CallGraph.usingModuleOp) { - System.out.println("Using ModuleOp for CudaBackend"); - kernelCallGraph.moduleOp.functionTable().forEach((_, funcOp) -> { - CoreOp.FuncOp loweredFunc = OpTk.lower(funcOp); - loweredFunc = transformPTXPtrs(kernelCallGraph.computeContext.accelerator.lookup,loweredFunc, argsMap, usedMathFns); - invokedMethods.append(createFunction(new PTXHATKernelBuilder(addressSize).nl().nl(), loweredFunc, false)); - }); - } else { + if (CallGraph.noModuleOp) { System.out.println("NOT using ModuleOp for CudaBackend"); for (KernelCallGraph.KernelReachableResolvedMethodCall k : kernelCallGraph.kernelReachableResolvedStream().toList()) { CoreOp.FuncOp calledFunc = k.funcOp(); @@ -427,6 +420,13 @@ String createPTX(KernelCallGraph kernelCallGraph, Object... args){ loweredFunc = transformPTXPtrs(kernelCallGraph.computeContext.accelerator.lookup,loweredFunc, argsMap, usedMathFns); invokedMethods.append(createFunction(new PTXHATKernelBuilder(addressSize).nl().nl(), loweredFunc, false)); } + } else { + System.out.println("Using ModuleOp for CudaBackend"); + kernelCallGraph.moduleOp.functionTable().forEach((_, funcOp) -> { + CoreOp.FuncOp loweredFunc = OpTk.lower(funcOp); + loweredFunc = transformPTXPtrs(kernelCallGraph.computeContext.accelerator.lookup,loweredFunc, argsMap, usedMathFns); + invokedMethods.append(createFunction(new PTXHATKernelBuilder(addressSize).nl().nl(), loweredFunc, false)); + }); } lowered = transformPTXPtrs(kernelCallGraph.computeContext.accelerator.lookup,lowered, argsMap, usedMathFns); diff --git a/hat/backends/ffi/mock/src/main/java/hat/backend/ffi/MockBackend.java b/hat/backends/ffi/mock/src/main/java/hat/backend/ffi/MockBackend.java index 8fbc3929d30..e933622445e 100644 --- a/hat/backends/ffi/mock/src/main/java/hat/backend/ffi/MockBackend.java +++ b/hat/backends/ffi/mock/src/main/java/hat/backend/ffi/MockBackend.java @@ -56,14 +56,14 @@ public void dispatchKernel(KernelCallGraph kernelCallGraph, NDRange ndRange, Obj // Here we receive a callgraph from the kernel entrypoint // The first time we see this we need to convert the kernel entrypoint // and rechable methods to a form that our mock backend can execute. - if (CallGraph.usingModuleOp) { - System.out.println("Using ModuleOp for MockBackend"); - kernelCallGraph.moduleOp.functionTable().forEach((_, funcOp) -> { - }); - } else { + if (CallGraph.noModuleOp) { System.out.println("NOT using ModuleOp for MockBackend"); kernelCallGraph.kernelReachableResolvedStream().forEach(kr -> { + }); + } else { + System.out.println("Using ModuleOp for MockBackend"); + kernelCallGraph.moduleOp.functionTable().forEach((_, funcOp) -> { }); } } diff --git a/hat/backends/ffi/shared/src/main/java/hat/backend/ffi/C99FFIBackend.java b/hat/backends/ffi/shared/src/main/java/hat/backend/ffi/C99FFIBackend.java index fbe1a6475aa..46e6fa52f53 100644 --- a/hat/backends/ffi/shared/src/main/java/hat/backend/ffi/C99FFIBackend.java +++ b/hat/backends/ffi/shared/src/main/java/hat/backend/ffi/C99FFIBackend.java @@ -233,14 +233,7 @@ public > String createCode(KernelCallGraph kern kernelCallGraph.entrypoint.funcOp()); // Sorting by rank ensures we don't need forward declarations - if (CallGraph.usingModuleOp) { - System.out.println("Using ModuleOp for C99FFIBackend"); - kernelCallGraph.moduleOp.functionTable() - .forEach((_, funcOp) -> builder - .nl() - .kernelMethod(buildContext,funcOp) - .nl()); - } else { + if (CallGraph.noModuleOp) { System.out.println("NOT using ModuleOp for C99FFIBackend"); kernelCallGraph.kernelReachableResolvedStream().sorted((lhs, rhs) -> rhs.rank - lhs.rank) .forEach(kernelReachableResolvedMethod -> @@ -248,6 +241,13 @@ public > String createCode(KernelCallGraph kern .nl() .kernelMethod(buildContext,kernelReachableResolvedMethod.funcOp()) .nl()); + } else { + System.out.println("Using ModuleOp for C99FFIBackend"); + kernelCallGraph.moduleOp.functionTable() + .forEach((_, funcOp) -> builder + .nl() + .kernelMethod(buildContext,funcOp) + .nl()); } builder.nl().kernelEntrypoint(buildContext, args).nl(); diff --git a/hat/core/src/main/java/hat/BufferTagger.java b/hat/core/src/main/java/hat/BufferTagger.java new file mode 100644 index 00000000000..f0341f9dc8a --- /dev/null +++ b/hat/core/src/main/java/hat/BufferTagger.java @@ -0,0 +1,214 @@ +package hat; + +import hat.buffer.Buffer; +import hat.ifacemapper.MappableIface; +import jdk.incubator.code.*; +import jdk.incubator.code.analysis.Inliner; +import jdk.incubator.code.analysis.SSA; +import jdk.incubator.code.dialect.core.CoreOp; +import jdk.incubator.code.dialect.java.*; + +import java.lang.invoke.MethodHandles; +import java.lang.reflect.Method; +import java.util.*; +import java.util.concurrent.atomic.AtomicBoolean; + +public class BufferTagger { + static HashMap accessMap = new HashMap<>(); + static HashMap remappedVals = new HashMap<>(); // maps values to their "root" parameter/value + static HashMap> blockParams = new HashMap<>(); // holds block parameters for easy lookup + + public enum AccessType { + NA(1), + RO(2), + WO(4), + RW(6), + NOT_BUFFER(0); + + public final int value; + AccessType(int i) { + value = i; + } + } + + // generates a list of AccessTypes matching the given FuncOp's parameter order + public static ArrayList getAccessList(MethodHandles.Lookup l, CoreOp.FuncOp f) { + CoreOp.FuncOp inlinedFunc = inlineLoop(l, f); + buildAccessMap(l, inlinedFunc); + ArrayList accessList = new ArrayList<>(); + for (Block.Parameter p : inlinedFunc.body().entryBlock().parameters()) { + if (accessMap.containsKey(p)) { + accessList.add(accessMap.get(p)); // is an accessed buffer + } else if (getClass(l, p.type()) instanceof Class c && MappableIface.class.isAssignableFrom(c)) { + accessList.add(AccessType.NA); // is a buffer but not accessed + } else { + accessList.add(AccessType.NOT_BUFFER); // is not a buffer + } + } + return accessList; + } + + // inlines functions found in FuncOp f until no more inline-able functions are present + public static CoreOp.FuncOp inlineLoop(MethodHandles.Lookup l, CoreOp.FuncOp f) { + CoreOp.FuncOp ssaFunc = SSA.transform(f.transform(OpTransformer.LOWERING_TRANSFORMER)); + AtomicBoolean changed = new AtomicBoolean(true); + while (changed.get()) { // loop until no more inline-able functions + changed.set(false); + ssaFunc = ssaFunc.transform((bb, op) -> { + if (op instanceof JavaOp.InvokeOp iop) { + MethodRef methodRef = iop.invokeDescriptor(); + Method invokeOpCalledMethod; + try { + invokeOpCalledMethod = methodRef.resolveToMethod(l, iop.invokeKind()); + } catch (ReflectiveOperationException _) { + throw new IllegalStateException("Could not resolve invokeOp to method"); + } + if (invokeOpCalledMethod instanceof Method method) { // if method isn't a buffer access (is code reflected) + if (Op.ofMethod(method).isPresent()) { + CoreOp.FuncOp inline = Op.ofMethod(method).get(); // method to be inlined + CoreOp.FuncOp ssaInline = SSA.transform(inline.transform(OpTransformer.LOWERING_TRANSFORMER)); + + Block.Builder exit = Inliner.inline(bb, ssaInline, bb.context().getValues(iop.operands()), (_, v) -> { + if (v != null) bb.context().mapValue(iop.result(), v); + }); + + if (!exit.parameters().isEmpty()) { + bb.context().mapValue(iop.result(), exit.parameters().getFirst()); + } + changed.set(true); + return exit.rebind(bb.context(), bb.transformer()); // return exit in same context as block + } + } + } + bb.op(op); + return bb; + }); + } + return ssaFunc; + } + + // creates the access map + public static void buildAccessMap(MethodHandles.Lookup l, CoreOp.FuncOp f) { + // build blockParams so that we can map params to "root" params later + for (Body b : f.bodies()) { + for (Block block : b.blocks()) { + if (!block.parameters().isEmpty()) { + blockParams.put(block, block.parameters()); + } + } + } + + f.traverse(null, (map, op) -> { + if (op instanceof CoreOp.BranchOp b) { + mapBranch(l, b.branch()); + } else if (op instanceof CoreOp.ConditionalBranchOp cb) { + mapBranch(l, cb.trueBranch()); // handle true branch + mapBranch(l, cb.falseBranch()); // handle false branch + } else if (op instanceof JavaOp.InvokeOp iop) { // (almost) all the buffer accesses happen here + if (isAssignable(l, iop.invokeDescriptor().refType(), MappableIface.class)) { + updateAccessType(getRootValue(iop), getAccessType(iop)); // update buffer access + if (isAssignable(l, iop.invokeDescriptor().refType(), Buffer.class) + && iop.result() != null && !(iop.resultType() instanceof PrimitiveType) + && isAssignable(l, iop.resultType(), MappableIface.class)) { + // if we access a struct/union from a buffer, we map the struct/union to the buffer root + remappedVals.put(iop.result(), getRootValue(iop)); + } + } + } else if (op instanceof CoreOp.VarOp vop) { // map the new VarOp to the "root" param + if (isAssignable(l, vop.resultType().valueType(), Buffer.class)) { + remappedVals.put(vop.initOperand(), getRootValue(vop)); + } + } else if (op instanceof JavaOp.FieldAccessOp.FieldLoadOp flop) { + if (isAssignable(l, flop.fieldDescriptor().refType(), KernelContext.class)) { + updateAccessType(getRootValue(flop), AccessType.RO); // handle kc access + } + } + return map; + }); + } + + // maps the parameters of a block to the values passed to a branch + public static void mapBranch(MethodHandles.Lookup l, Block.Reference b) { + List args = b.arguments(); + for (int i = 0; i < args.size(); i++) { + Value key = blockParams.get(b.targetBlock()).get(i); + Value val = args.get(i); + + if (val instanceof Op.Result) { + // either find root param or it doesnt exist (is a constant for example) + if (isAssignable(l, val.type(), MappableIface.class)) { + val = getRootValue(((Op.Result) val).op()); + if (val instanceof Block.Parameter) { + val = remappedVals.getOrDefault(val, val); + } + } + } + remappedVals.put(key, val); + } + } + + // checks if a TypeElement is assignable to a certain class + public static boolean isAssignable(MethodHandles.Lookup l, TypeElement type, Class clazz) { + Class fopClass = getClass(l, type); + return (fopClass != null && (clazz.isAssignableFrom(fopClass))); + } + + // retrieves the class of a TypeElement + public static Class getClass(MethodHandles.Lookup l, TypeElement type) { + if (type instanceof ClassType classType) { + try { + return (Class) classType.resolve(l); + } catch (ReflectiveOperationException e) { + throw new RuntimeException(e); + } + } + return null; + } + + // retrieves "root" value of an op, the origin of the parameter (or value) used by the op + public static Value getRootValue(Op op) { + if (op.operands().isEmpty()) { + return op.result(); + } + if (op.operands().getFirst() instanceof Block.Parameter param) { + return param; + } + Value val = op.operands().getFirst(); + while (!(val instanceof Block.Parameter)) { + // or if the "root VarOp" is an invoke (not sure how to tell) + // if (tempOp instanceof JavaOp.InvokeOp iop + // && ((TypeElement) iop.resultType()) instanceof ClassType classType + // && !hasOperandType(iop, classType)) return ((CoreOp.VarOp) op); + val = ((Op.Result) val).op().operands().getFirst(); + } + return val; + } + + // retrieves accessType based on return value of InvokeOp + public static AccessType getAccessType(JavaOp.InvokeOp iop) { + return iop.invokeDescriptor().type().returnType().equals(JavaType.VOID) ? AccessType.WO : AccessType.RO; + } + + // updates accessMap + public static void updateAccessType(Value val, AccessType curAccess) { + Value remappedVal = remappedVals.getOrDefault(val, val); + AccessType storedAccess = accessMap.get(remappedVal); + if (storedAccess == null) { + accessMap.put(remappedVal, curAccess); + } else if (curAccess != storedAccess && storedAccess != AccessType.RW) { + accessMap.put(remappedVal, AccessType.RW); + } + } + + public static void printAccessMap() { + System.out.println("access map output:"); + for (Value val : accessMap.keySet()) { + if (val instanceof Block.Parameter param) { + System.out.println("\t" + ((CoreOp.FuncOp) param.declaringBlock().parent().parent()).funcName() + + " param w/ idx " + param.index() + ": " + accessMap.get(val)); + } else { + System.out.println("\t" + val.toString() + ": " + accessMap.get(val)); + } + } + } +} \ No newline at end of file diff --git a/hat/core/src/main/java/hat/buffer/ArgArray.java b/hat/core/src/main/java/hat/buffer/ArgArray.java index fb217c03b4d..e78f6729a54 100644 --- a/hat/core/src/main/java/hat/buffer/ArgArray.java +++ b/hat/core/src/main/java/hat/buffer/ArgArray.java @@ -25,6 +25,7 @@ package hat.buffer; import hat.Accelerator; +import hat.BufferTagger; import hat.ComputeContext; import hat.callgraph.KernelCallGraph; import hat.ifacemapper.Schema; @@ -34,6 +35,7 @@ import java.lang.foreign.ValueLayout; import java.lang.invoke.MethodHandles; import java.nio.ByteOrder; +import java.util.ArrayList; import static hat.buffer.ArgArray.Arg.Value.Buf.UNKNOWN_BYTE; import static java.lang.foreign.ValueLayout.JAVA_BYTE; @@ -289,6 +291,8 @@ static ArgArray create(Accelerator accelerator, KernelCallGraph kernelCallGraph, static void update(ArgArray argArray, KernelCallGraph kernelCallGraph, Object... args) { Annotation[][] parameterAnnotations = kernelCallGraph.entrypoint.getMethod().getParameterAnnotations(); + ArrayList bufferAccessList = kernelCallGraph.bufferAccessList; + for (int i = 0; i < args.length; i++) { Object argObject = args[i]; Arg arg = argArray.arg(i); // this should be invariant, but if we are called from create it will be 0 for all @@ -324,6 +328,8 @@ static void update(ArgArray argArray, KernelCallGraph kernelCallGraph, Object... buf.address(segment); buf.bytes(segment.byteSize()); buf.access(accessByte); + + assert bufferAccessList.get(i).value == accessByte; } default -> throw new IllegalStateException("Unexpected value: " + argObject); } diff --git a/hat/core/src/main/java/hat/callgraph/CallGraph.java b/hat/core/src/main/java/hat/callgraph/CallGraph.java index 5788e4fae2a..c3cf9ff6d47 100644 --- a/hat/core/src/main/java/hat/callgraph/CallGraph.java +++ b/hat/core/src/main/java/hat/callgraph/CallGraph.java @@ -42,7 +42,7 @@ public abstract class CallGraph { public final Set calls = new HashSet<>(); public final Map methodRefToMethodCallMap = new LinkedHashMap<>(); public CoreOp.ModuleOp moduleOp; - public static boolean usingModuleOp = Boolean.getBoolean("moduleOp"); + public static boolean noModuleOp = Boolean.getBoolean("noModuleOp"); public Stream callStream() { return methodRefToMethodCallMap.values().stream(); } diff --git a/hat/core/src/main/java/hat/callgraph/ComputeCallGraph.java b/hat/core/src/main/java/hat/callgraph/ComputeCallGraph.java index 02d61c9c7ab..f1c29fb6a72 100644 --- a/hat/core/src/main/java/hat/callgraph/ComputeCallGraph.java +++ b/hat/core/src/main/java/hat/callgraph/ComputeCallGraph.java @@ -216,10 +216,10 @@ public void updateDag(ComputeReachableResolvedMethodCall computeReachableResolve } public void close() { - if (CallGraph.usingModuleOp) { - closeWithModuleOp(entrypoint); - } else { + if (CallGraph.noModuleOp) { updateDag(entrypoint); + } else { + closeWithModuleOp(entrypoint); } } diff --git a/hat/core/src/main/java/hat/callgraph/KernelCallGraph.java b/hat/core/src/main/java/hat/callgraph/KernelCallGraph.java index 26c63bc6430..3222233e494 100644 --- a/hat/core/src/main/java/hat/callgraph/KernelCallGraph.java +++ b/hat/core/src/main/java/hat/callgraph/KernelCallGraph.java @@ -24,6 +24,7 @@ */ package hat.callgraph; +import hat.BufferTagger; import hat.buffer.Buffer; import hat.optools.OpTk; import jdk.incubator.code.Op; @@ -38,6 +39,7 @@ public class KernelCallGraph extends CallGraph { public final ComputeCallGraph computeCallGraph; public final Map bufferAccessToMethodCallMap = new LinkedHashMap<>(); + public final ArrayList bufferAccessList; public interface KernelReachable { } @@ -77,6 +79,7 @@ public Stream kernelReachableResolvedStream() super(computeCallGraph.computeContext, new KernelEntrypoint(null, methodRef, method, funcOp)); entrypoint.callGraph = this; this.computeCallGraph = computeCallGraph; + bufferAccessList = BufferTagger.getAccessList(computeContext.accelerator.lookup, entrypoint.funcOp()); } void updateDag(KernelReachableResolvedMethodCall kernelReachableResolvedMethodCall) { diff --git a/hat/hat/Script.java b/hat/hat/Script.java index e2f38e863c0..c542118e41a 100644 --- a/hat/hat/Script.java +++ b/hat/hat/Script.java @@ -1284,7 +1284,7 @@ public static final class JavaBuilder extends JavaToolBuilder { public StringList args = new StringList(); public StringList nativeAccessModules = new StringList(); private boolean headless; - public boolean moduleOp; + public boolean noModuleOp; public JavaBuilder enable_native_access(String module) { @@ -1344,8 +1344,8 @@ public void headless() { this.headless = true; } - public void moduleOp() { - this.moduleOp = true; + public void noModuleOp() { + this.noModuleOp = true; } } @@ -1378,8 +1378,8 @@ public static JavaBuilder java(JavaBuilder javaBuilder) { if (javaBuilder.headless) { result.opts.add("-Dheadless=true"); } - if (javaBuilder.moduleOp) { - result.opts.add("-DmoduleOp=true"); + if (javaBuilder.noModuleOp) { + result.opts.add("-DnoModuleOp=true"); } if (javaBuilder.startOnFirstThread) { result.opts.add("-XstartOnFirstThread"); diff --git a/hat/hat/run.java b/hat/hat/run.java index f8d620bea4b..5bd02353596 100644 --- a/hat/hat/run.java +++ b/hat/hat/run.java @@ -28,7 +28,7 @@ class Config{ boolean headless=false; - boolean moduleOp = false; + boolean noModuleOp = false; boolean verbose = false; boolean startOnFirstThread = false; boolean justShowCommandline = false; @@ -72,7 +72,7 @@ class Config{ }else{ switch (args[arg]) { case "headless" -> headless = true; - case "moduleOp" -> moduleOp = true; + case "noModuleOp" -> noModuleOp = true; case "verbose" -> verbose = true; case "justShowCommandLine" -> justShowCommandline = true; case "startOnFirstThread" -> startOnFirstThread = true; @@ -172,7 +172,7 @@ class name is assumed to be package.Main (i.e. mandel.main) } default -> {} } - if (config.moduleOp) System.out.println("Using ModuleOp for CallGraphs"); + if (config.noModuleOp) System.out.println("NOT using ModuleOp for CallGraphs"); } Script.java(java -> java .enable_preview() @@ -180,7 +180,7 @@ class name is assumed to be package.Main (i.e. mandel.main) .enable_native_access("ALL-UNNAMED") .library_path(buildDir) .when(config.headless, Script.JavaBuilder::headless) - .when(config.moduleOp, Script.JavaBuilder::moduleOp) + .when(config.noModuleOp, Script.JavaBuilder::noModuleOp) .when(config.startOnFirstThread, Script.JavaBuilder::start_on_first_thread) .class_path(config.classpath) .vmargs(config.vmargs)