diff --git a/config/opal_check_cudart.m4 b/config/opal_check_cudart.m4 new file mode 100644 index 00000000000..0e3fced8065 --- /dev/null +++ b/config/opal_check_cudart.m4 @@ -0,0 +1,120 @@ +dnl -*- autoconf -*- +dnl +dnl Copyright (c) 2004-2010 The Trustees of Indiana University and Indiana +dnl University Research and Technology +dnl Corporation. All rights reserved. +dnl Copyright (c) 2004-2005 The University of Tennessee and The University +dnl of Tennessee Research Foundation. All rights +dnl reserved. +dnl Copyright (c) 2004-2005 High Performance Computing Center Stuttgart, +dnl University of Stuttgart. All rights reserved. +dnl Copyright (c) 2004-2005 The Regents of the University of California. +dnl All rights reserved. +dnl Copyright (c) 2006-2016 Cisco Systems, Inc. All rights reserved. +dnl Copyright (c) 2007 Sun Microsystems, Inc. All rights reserved. +dnl Copyright (c) 2009 IBM Corporation. All rights reserved. +dnl Copyright (c) 2009 Los Alamos National Security, LLC. All rights +dnl reserved. +dnl Copyright (c) 2009-2011 Oak Ridge National Labs. All rights reserved. +dnl Copyright (c) 2011-2015 NVIDIA Corporation. All rights reserved. +dnl Copyright (c) 2015 Research Organization for Information Science +dnl and Technology (RIST). All rights reserved. +dnl Copyright (c) 2022 Amazon.com, Inc. or its affiliates. All Rights reserved. +dnl $COPYRIGHT$ +dnl +dnl Additional copyrights may follow +dnl +dnl $HEADER$ +dnl + + +# OPAL_CHECK_CUDART(prefix, [action-if-found], [action-if-not-found]) +# -------------------------------------------------------- +# check if CUDA runtime library support can be found. 
sets prefix_{CPPFLAGS, +# LDFLAGS, LIBS} as needed and runs action-if-found if there is +# support, otherwise executes action-if-not-found + +# +# Check for CUDA support +# +AC_DEFUN([OPAL_CHECK_CUDART],[ +OPAL_VAR_SCOPE_PUSH([cudart_save_CPPFLAGS cudart_save_LDFLAGS cudart_save_LIBS]) + +cudart_save_CPPFLAGS="$CPPFLAGS" +cudart_save_LDFLAGS="$LDFLAGS" +cudart_save_LIBS="$LIBS" + +# +# Check to see if the user provided paths for CUDART +# +AC_ARG_WITH([cudart], + [AS_HELP_STRING([--with-cudart=DIR], + [Path to the CUDA runtime library and header files])]) +AC_MSG_CHECKING([if --with-cudart is set]) +AC_ARG_WITH([cudart-libdir], + [AS_HELP_STRING([--with-cudart-libdir=DIR], + [Search for CUDA runtime libraries in DIR])]) + +#################################### +#### Check for CUDA runtime library +#################################### +AS_IF([test "x$with_cudart" != "xno" || test "x$with_cudart" = "x"], + [opal_check_cudart_happy=no + AC_MSG_RESULT([not set (--with-cudart=$with_cudart)])], + [AS_IF([test ! 
-d "$with_cudart"], + [AC_MSG_RESULT([not found]) + AC_MSG_WARN([Directory $with_cudart not found])] + [AS_IF([test "x`ls $with_cudart/include/cuda_runtime.h 2> /dev/null`" = "x"] + [AC_MSG_RESULT([not found]) + AC_MSG_WARN([Could not find cuda_runtime.h in $with_cudart/include])] + [opal_check_cudart_happy=yes + opal_cudart_incdir="$with_cudart/include"])])]) + +AS_IF([test "$opal_check_cudart_happy" = "no" && test "$with_cudart" != "no"], + [AC_PATH_PROG([nvcc_bin], [nvcc], ["not-found"]) + AS_IF([test "$nvcc_bin" = "not-found"], + [AC_MSG_WARN([Could not find nvcc binary])], + [nvcc_dirname=`AS_DIRNAME([$nvcc_bin])` + with_cudart=$nvcc_dirname/../ + opal_cudart_incdir=$nvcc_dirname/../include + opal_check_cudart_happy=yes]) + ] + []) + +AS_IF([test x"$with_cudart_libdir" = "x"], + [with_cudart_libdir=$with_cudart/lib64/] + []) + +AS_IF([test "$opal_check_cudart_happy" = "yes"], + [OAC_CHECK_PACKAGE([cudart], + [$1], + [cuda_runtime.h], + [cudart], + [cudaMalloc], + [opal_check_cudart_happy="yes"], + [opal_check_cudart_happy="no"])], + []) + + +AC_MSG_CHECKING([if have cuda runtime library support]) +if test "$opal_check_cudart_happy" = "yes"; then + AC_MSG_RESULT([yes (-I$opal_cudart_incdir)]) + CUDART_SUPPORT=1 + common_cudart_CPPFLAGS="-I$opal_cudart_incdir" + AC_SUBST([common_cudart_CPPFLAGS]) +else + AC_MSG_RESULT([no]) + CUDART_SUPPORT=0 +fi + + +OPAL_SUMMARY_ADD([Accelerators], [CUDART support], [], [$opal_check_cudart_happy]) +AM_CONDITIONAL([OPAL_cudart_support], [test "x$CUDART_SUPPORT" = "x1"]) +AC_DEFINE_UNQUOTED([OPAL_CUDART_SUPPORT],$CUDART_SUPPORT, + [Whether we have cuda runtime library support]) + +CPPFLAGS=${cudart_save_CPPFLAGS} +LDFLAGS=${cudart_save_LDFLAGS} +LIBS=${cudart_save_LIBS} +OPAL_VAR_SCOPE_POP +])dnl diff --git a/ompi/datatype/ompi_datatype.h b/ompi/datatype/ompi_datatype.h index 0c77079b916..5069b7e90a5 100644 --- a/ompi/datatype/ompi_datatype.h +++ b/ompi/datatype/ompi_datatype.h @@ -275,8 +275,9 @@ 
ompi_datatype_set_element_count( const ompi_datatype_t* type, size_t count, size } static inline int32_t -ompi_datatype_copy_content_same_ddt( const ompi_datatype_t* type, size_t count, - char* pDestBuf, char* pSrcBuf ) +ompi_datatype_copy_content_same_ddt_stream( const ompi_datatype_t* type, size_t count, + char* pDestBuf, char* pSrcBuf, + opal_accelerator_stream_t *stream ) { int32_t length, rc; ptrdiff_t extent; @@ -285,8 +286,8 @@ ompi_datatype_copy_content_same_ddt( const ompi_datatype_t* type, size_t count, while( 0 != count ) { length = INT_MAX; if( ((size_t)length) > count ) length = (int32_t)count; - rc = opal_datatype_copy_content_same_ddt( &type->super, length, - pDestBuf, pSrcBuf ); + rc = opal_datatype_copy_content_same_ddt_stream( &type->super, length, + pDestBuf, pSrcBuf, stream ); if( 0 != rc ) return rc; pDestBuf += ((ptrdiff_t)length) * extent; pSrcBuf += ((ptrdiff_t)length) * extent; @@ -295,6 +296,13 @@ ompi_datatype_copy_content_same_ddt( const ompi_datatype_t* type, size_t count, return 0; } +static inline int32_t +ompi_datatype_copy_content_same_ddt( const ompi_datatype_t* type, size_t count, + char* pDestBuf, char* pSrcBuf ) +{ + return ompi_datatype_copy_content_same_ddt_stream(type, count, pDestBuf, pSrcBuf, NULL); +} + OMPI_DECLSPEC const ompi_datatype_t* ompi_datatype_match_size( int size, uint16_t datakind, uint16_t datalang ); /* diff --git a/ompi/mca/coll/base/coll_base_allreduce.c b/ompi/mca/coll/base/coll_base_allreduce.c index 30ab0a4f869..fdc77d1f416 100644 --- a/ompi/mca/coll/base/coll_base_allreduce.c +++ b/ompi/mca/coll/base/coll_base_allreduce.c @@ -140,6 +140,7 @@ ompi_coll_base_allreduce_intra_recursivedoubling(const void *sbuf, void *rbuf, int ret, line, rank, size, adjsize, remote, distance; int newrank, newremote, extra_ranks; char *tmpsend = NULL, *tmprecv = NULL, *tmpswap = NULL, *inplacebuf_free = NULL, *inplacebuf; + char *recvbuf = NULL; ptrdiff_t span, gap = 0; size = ompi_comm_size(comm); @@ -157,22 +158,64 @@ 
ompi_coll_base_allreduce_intra_recursivedoubling(const void *sbuf, void *rbuf, return MPI_SUCCESS; } - /* Allocate and initialize temporary send buffer */ + /* get the device for sbuf and rbuf and where the op would like to execute */ + int sendbuf_dev, recvbuf_dev, op_dev; + uint64_t sendbuf_flags, recvbuf_flags; + ompi_coll_base_select_device(op, sbuf, rbuf, count, dtype, &sendbuf_dev, &recvbuf_dev, + &sendbuf_flags, &recvbuf_flags, &op_dev); span = opal_datatype_span(&dtype->super, count, &gap); - inplacebuf_free = (char*) malloc(span); + inplacebuf_free = ompi_coll_base_allocate_on_device(op_dev, span, module); if (NULL == inplacebuf_free) { ret = -1; line = __LINE__; goto error_hndl; } inplacebuf = inplacebuf_free - gap; + //printf("allreduce ring count %d sbuf_dev %d rbuf_dev %d op_dev %d\n", count, sendbuf_dev, recvbuf_dev, op_dev); - if (MPI_IN_PLACE == sbuf) { - ret = ompi_datatype_copy_content_same_ddt(dtype, count, inplacebuf, (char*)rbuf); - if (ret < 0) { line = __LINE__; goto error_hndl; } - } else { - ret = ompi_datatype_copy_content_same_ddt(dtype, count, inplacebuf, (char*)sbuf); + opal_accelerator_stream_t *stream = NULL; + if (op_dev >= 0) { + opal_accelerator.get_default_stream(op_dev, &stream); + } + + tmpsend = (char*) sbuf; + if (op_dev != recvbuf_dev) { + /* copy data to where the op wants it to be */ + if (MPI_IN_PLACE == sbuf) { + ret = ompi_datatype_copy_content_same_ddt_stream(dtype, count, inplacebuf, (char*)rbuf, stream); + if (ret < 0) { line = __LINE__; goto error_hndl; } + } + /* only copy if op is on the device or we cannot access the sendbuf on the host */ + else if (op_dev != MCA_ACCELERATOR_NO_DEVICE_ID || + 0 == (sendbuf_flags & MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY)) { + ret = ompi_datatype_copy_content_same_ddt_stream(dtype, count, inplacebuf, (char*)sbuf, stream); + if (ret < 0) { line = __LINE__; goto error_hndl; } + } + tmpsend = (char*) inplacebuf; + } else if (MPI_IN_PLACE == sbuf) { + ret = 
ompi_datatype_copy_content_same_ddt_stream(dtype, count, inplacebuf, (char*)rbuf, stream); if (ret < 0) { line = __LINE__; goto error_hndl; } + tmpsend = (char*) inplacebuf; } - tmpsend = (char*) inplacebuf; - tmprecv = (char*) rbuf; + /* Handle MPI_IN_PLACE */ + bool use_sbuf = (MPI_IN_PLACE != sbuf); + /* allocate temporary recv buffer if the tmpbuf above is on a different device than the rbuf + * and the op is on the device or we cannot access the recv buffer on the host */ + recvbuf = rbuf; + bool free_recvbuf = false; + if (op_dev != recvbuf_dev && + (op_dev != MCA_ACCELERATOR_NO_DEVICE_ID || + 0 == (recvbuf_flags & MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY))) { + recvbuf = ompi_coll_base_allocate_on_device(op_dev, span, module); + free_recvbuf = true; + if (use_sbuf) { + /* copy from sbuf */ + ompi_datatype_copy_content_same_ddt_stream(dtype, count, (char*)recvbuf, (char*)sbuf, stream); + } else { + /* copy from rbuf */ + ompi_datatype_copy_content_same_ddt_stream(dtype, count, (char*)recvbuf, (char*)rbuf, stream); + } + use_sbuf = false; + } + + tmprecv = (char*) recvbuf; /* Determine nearest power of two less than or equal to size */ adjsize = opal_next_poweroftwo (size); @@ -188,6 +231,11 @@ ompi_coll_base_allreduce_intra_recursivedoubling(const void *sbuf, void *rbuf, extra_ranks = size - adjsize; if (rank < (2 * extra_ranks)) { if (0 == (rank % 2)) { + /* wait for above copies to complete */ + if (NULL != stream) { + opal_accelerator.wait_stream(stream); + } + /* wait for tmpsend to be copied */ ret = MCA_PML_CALL(send(tmpsend, count, dtype, (rank + 1), MCA_COLL_BASE_TAG_ALLREDUCE, MCA_PML_BASE_SEND_STANDARD, comm)); @@ -198,8 +246,14 @@ ompi_coll_base_allreduce_intra_recursivedoubling(const void *sbuf, void *rbuf, MCA_COLL_BASE_TAG_ALLREDUCE, comm, MPI_STATUS_IGNORE)); if (MPI_SUCCESS != ret) { line = __LINE__; goto error_hndl; } - /* tmpsend = tmprecv (op) tmpsend */ - ompi_op_reduce(op, tmprecv, tmpsend, count, dtype); + if (tmpsend == sbuf) { + tmpsend = 
inplacebuf; + /* tmpsend = tmprecv (op) sbuf */ + ompi_3buff_op_reduce_stream(op, sbuf, tmprecv, tmpsend, count, dtype, op_dev, stream); + } else { + /* tmpsend = tmprecv (op) tmpsend */ + ompi_op_reduce_stream(op, tmprecv, tmpsend, count, dtype, op_dev, stream); + } newrank = rank >> 1; } } else { @@ -218,6 +272,12 @@ ompi_coll_base_allreduce_intra_recursivedoubling(const void *sbuf, void *rbuf, remote = (newremote < extra_ranks)? (newremote * 2 + 1):(newremote + extra_ranks); + bool have_next_iter = ((distance << 1) < adjsize); + + /* wait for previous ops to complete */ + if (NULL != stream) { + opal_accelerator.wait_stream(stream); + } /* Exchange the data */ ret = ompi_coll_base_sendrecv_actual(tmpsend, count, dtype, remote, MCA_COLL_BASE_TAG_ALLREDUCE, @@ -228,14 +288,47 @@ /* Apply operation */ if (rank < remote) { - /* tmprecv = tmpsend (op) tmprecv */ - ompi_op_reduce(op, tmpsend, tmprecv, count, dtype); - tmpswap = tmprecv; - tmprecv = tmpsend; - tmpsend = tmpswap; + if (tmpsend == sbuf) { + /* special case: 1st iteration takes one input from the sbuf */ + /* tmprecv = sbuf (op) tmprecv */ + ompi_op_reduce_stream(op, sbuf, tmprecv, count, dtype, op_dev, stream); + /* send the current recv buffer, and use the tmp buffer to receive */ + tmpsend = tmprecv; + tmprecv = inplacebuf; + } else if (have_next_iter || tmprecv == recvbuf) { + /* All iterations, and the last if tmprecv is the recv buffer */ + /* tmprecv = tmpsend (op) tmprecv */ + ompi_op_reduce_stream(op, tmpsend, tmprecv, count, dtype, op_dev, stream); + /* swap send and receive buffers */ + tmpswap = tmprecv; + tmprecv = tmpsend; + tmpsend = tmpswap; + } else { + /* Last iteration if tmprecv is not the recv buffer, then tmpsend is */ + /* Make sure we reduce into the receive buffer + * tmpsend = tmprecv (op) tmpsend */ + ompi_op_reduce_stream(op, tmprecv, tmpsend, count, dtype, op_dev, stream); + } } else { - 
/* tmpsend = tmprecv (op) tmpsend */ - ompi_op_reduce(op, tmprecv, tmpsend, count, dtype); + if (tmpsend == sbuf) { + /* First iteration: use input from sbuf */ + /* tmpsend = tmprecv (op) sbuf */ + tmpsend = inplacebuf; + if (have_next_iter || tmpsend == recvbuf) { + ompi_3buff_op_reduce_stream(op, tmprecv, sbuf, tmpsend, count, dtype, op_dev, stream); + } else { + ompi_op_reduce_stream(op, sbuf, tmprecv, count, dtype, op_dev, stream); + tmpsend = tmprecv; + } + } else if (have_next_iter || tmpsend == recvbuf) { + /* All other iterations: reduce into tmpsend for next iteration */ + /* tmpsend = tmprecv (op) tmpsend */ + ompi_op_reduce_stream(op, tmprecv, tmpsend, count, dtype, op_dev, stream); + } else { + /* Last iteration: reduce into rbuf and set tmpsend to rbuf (needed at the end) */ + ompi_op_reduce_stream(op, tmpsend, tmprecv, count, dtype, op_dev, stream); + tmpsend = tmprecv; + } } } @@ -252,6 +345,10 @@ ompi_coll_base_allreduce_intra_recursivedoubling(const void *sbuf, void *rbuf, if (MPI_SUCCESS != ret) { line = __LINE__; goto error_hndl; } tmpsend = (char*)rbuf; } else { + /* wait for previous ops to complete */ + if (NULL != stream) { + opal_accelerator.wait_stream(stream); + } ret = MCA_PML_CALL(send(tmpsend, count, dtype, (rank - 1), MCA_COLL_BASE_TAG_ALLREDUCE, MCA_PML_BASE_SEND_STANDARD, comm)); @@ -261,18 +358,31 @@ ompi_coll_base_allreduce_intra_recursivedoubling(const void *sbuf, void *rbuf, /* Ensure that the final result is in rbuf */ if (tmpsend != rbuf) { - ret = ompi_datatype_copy_content_same_ddt(dtype, count, (char*)rbuf, tmpsend); + /* TODO: catch this case in the 3buf selection above. Maybe already caught? 
*/ + ret = ompi_datatype_copy_content_same_ddt_stream(dtype, count, (char*)rbuf, tmpsend, stream); if (ret < 0) { line = __LINE__; goto error_hndl; } } - if (NULL != inplacebuf_free) free(inplacebuf_free); + /* wait for previous ops to complete */ + if (NULL != stream) { + opal_accelerator.wait_stream(stream); + } + ompi_coll_base_free_tmpbuf(inplacebuf_free, op_dev, module); + + if (free_recvbuf) { + ompi_coll_base_free_tmpbuf(recvbuf, op_dev, module); + } return MPI_SUCCESS; error_hndl: OPAL_OUTPUT((ompi_coll_base_framework.framework_output, "%s:%4d\tRank %d Error occurred %d\n", __FILE__, line, rank, ret)); (void)line; // silence compiler warning - if (NULL != inplacebuf_free) free(inplacebuf_free); + ompi_coll_base_free_tmpbuf(inplacebuf_free, op_dev, module); + + if (op_dev != recvbuf_dev) { + ompi_coll_base_free_tmpbuf(recvbuf, op_dev, module); + } return ret; } @@ -351,6 +461,7 @@ ompi_coll_base_allreduce_intra_ring(const void *sbuf, void *rbuf, int count, int early_segcount, late_segcount, split_rank, max_segcount; size_t typelng; char *tmpsend = NULL, *tmprecv = NULL, *inbuf[2] = {NULL, NULL}; + void *recvbuf = NULL; ptrdiff_t true_lb, true_extent, lb, extent; ptrdiff_t block_offset, max_real_segsize; ompi_request_t *reqs[2] = {MPI_REQUEST_NULL, MPI_REQUEST_NULL}; @@ -399,18 +510,37 @@ ompi_coll_base_allreduce_intra_ring(const void *sbuf, void *rbuf, int count, max_segcount = early_segcount; max_real_segsize = true_extent + (max_segcount - 1) * extent; - - inbuf[0] = (char*)malloc(max_real_segsize); - if (NULL == inbuf[0]) { ret = -1; line = __LINE__; goto error_hndl; } + /* get the device for sbuf and rbuf and where the op would like to execute */ + int sendbuf_dev, recvbuf_dev, op_dev; + uint64_t sendbuf_flags, recvbuf_flags; + ompi_coll_base_select_device(op, sbuf, rbuf, count, dtype, &sendbuf_dev, &recvbuf_dev, + &sendbuf_flags, &recvbuf_flags, &op_dev); if (size > 2) { - inbuf[1] = (char*)malloc(max_real_segsize); - if (NULL == inbuf[1]) { ret = -1; 
line = __LINE__; goto error_hndl; } + inbuf[0] = ompi_coll_base_allocate_on_device(op_dev, 2*max_real_segsize, module); + if (NULL == inbuf[0]) { ret = -1; line = __LINE__; goto error_hndl; } + inbuf[1] = inbuf[0] + max_real_segsize; + } else { + inbuf[0] = ompi_coll_base_allocate_on_device(op_dev, max_real_segsize, module); + if (NULL == inbuf[0]) { ret = -1; line = __LINE__; goto error_hndl; } } + //printf("allreduce ring count %d sbuf_dev %d rbuf_dev %d op_dev %d\n", count, sendbuf_dev, recvbuf_dev, op_dev); /* Handle MPI_IN_PLACE */ - if (MPI_IN_PLACE != sbuf) { - ret = ompi_datatype_copy_content_same_ddt(dtype, count, (char*)rbuf, (char*)sbuf); - if (ret < 0) { line = __LINE__; goto error_hndl; } + bool use_sbuf = (MPI_IN_PLACE != sbuf); + /* allocate temporary recv buffer if the tmpbuf above is on a different device than the rbuf */ + recvbuf = rbuf; + if (op_dev != recvbuf_dev && + /* only copy if op is on the device or the recvbuffer cannot be accessed on the host */ + (op_dev != MCA_ACCELERATOR_NO_DEVICE_ID || 0 == (MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY & recvbuf_flags))) { + recvbuf = ompi_coll_base_allocate_on_device(op_dev, typelng*count, module); + if (use_sbuf) { + /* copy from sbuf */ + ompi_datatype_copy_content_same_ddt(dtype, count, (char*)recvbuf, (char*)sbuf); + } else { + /* copy from rbuf */ + ompi_datatype_copy_content_same_ddt(dtype, count, (char*)recvbuf, (char*)rbuf); + } + use_sbuf = false; } /* Computation loop */ @@ -443,7 +573,7 @@ ompi_coll_base_allreduce_intra_ring(const void *sbuf, void *rbuf, int count, ((ptrdiff_t)rank * (ptrdiff_t)early_segcount) : ((ptrdiff_t)rank * (ptrdiff_t)late_segcount + split_rank)); block_count = ((rank < split_rank)? early_segcount : late_segcount); - tmpsend = ((char*)rbuf) + block_offset * extent; + tmpsend = ((use_sbuf) ? 
((char*)sbuf) : ((char*)recvbuf)) + block_offset * extent; ret = MCA_PML_CALL(send(tmpsend, block_count, dtype, send_to, MCA_COLL_BASE_TAG_ALLREDUCE, MCA_PML_BASE_SEND_STANDARD, comm)); @@ -465,13 +595,22 @@ ompi_coll_base_allreduce_intra_ring(const void *sbuf, void *rbuf, int count, /* Apply operation on previous block: result goes to rbuf rbuf[prevblock] = inbuf[inbi ^ 0x1] (op) rbuf[prevblock] - */ + */ block_offset = ((prevblock < split_rank)? ((ptrdiff_t)prevblock * early_segcount) : ((ptrdiff_t)prevblock * late_segcount + split_rank)); block_count = ((prevblock < split_rank)? early_segcount : late_segcount); - tmprecv = ((char*)rbuf) + (ptrdiff_t)block_offset * extent; - ompi_op_reduce(op, inbuf[inbi ^ 0x1], tmprecv, block_count, dtype); + tmprecv = ((char*)recvbuf) + (ptrdiff_t)block_offset * extent; + if (use_sbuf) { + void *tmpsbuf = ((char*)sbuf) + (ptrdiff_t)block_offset * extent; + /* tmprecv = inbuf[inbi ^ 0x1] (op) sbuf */ + ompi_3buff_op_reduce_stream(op, inbuf[inbi ^ 0x1], tmpsbuf, tmprecv, block_count, + dtype, op_dev, NULL); + } else { + /* tmprecv = inbuf[inbi ^ 0x1] (op) tmprecv */ + ompi_op_reduce_stream(op, inbuf[inbi ^ 0x1], tmprecv, block_count, + dtype, op_dev, NULL); + } /* send previous block to send_to */ ret = MCA_PML_CALL(send(tmprecv, block_count, dtype, send_to, @@ -491,8 +630,8 @@ ompi_coll_base_allreduce_intra_ring(const void *sbuf, void *rbuf, int count, ((ptrdiff_t)recv_from * early_segcount) : ((ptrdiff_t)recv_from * late_segcount + split_rank)); block_count = ((recv_from < split_rank)? 
early_segcount : late_segcount); - tmprecv = ((char*)rbuf) + (ptrdiff_t)block_offset * extent; - ompi_op_reduce(op, inbuf[inbi], tmprecv, block_count, dtype); + tmprecv = ((char*)recvbuf) + (ptrdiff_t)block_offset * extent; + ompi_op_reduce_stream(op, inbuf[inbi], tmprecv, block_count, dtype, op_dev, NULL); /* Distribution loop - variation of ring allgather */ send_to = (rank + 1) % size; @@ -511,8 +650,8 @@ ompi_coll_base_allreduce_intra_ring(const void *sbuf, void *rbuf, int count, block_count = ((send_data_from < split_rank)? early_segcount : late_segcount); - tmprecv = (char*)rbuf + (ptrdiff_t)recv_block_offset * extent; - tmpsend = (char*)rbuf + (ptrdiff_t)send_block_offset * extent; + tmprecv = (char*)recvbuf + (ptrdiff_t)recv_block_offset * extent; + tmpsend = (char*)recvbuf + (ptrdiff_t)send_block_offset * extent; ret = ompi_coll_base_sendrecv(tmpsend, block_count, dtype, send_to, MCA_COLL_BASE_TAG_ALLREDUCE, @@ -520,11 +659,14 @@ ompi_coll_base_allreduce_intra_ring(const void *sbuf, void *rbuf, int count, MCA_COLL_BASE_TAG_ALLREDUCE, comm, MPI_STATUS_IGNORE, rank); if (MPI_SUCCESS != ret) { line = __LINE__; goto error_hndl;} - } - if (NULL != inbuf[0]) free(inbuf[0]); - if (NULL != inbuf[1]) free(inbuf[1]); + ompi_coll_base_free_tmpbuf(inbuf[0], op_dev, module); + if (recvbuf != rbuf) { + /* copy to final rbuf and release temporary recvbuf */ + ompi_datatype_copy_content_same_ddt(dtype, count, (char*)rbuf, (char*)recvbuf); + ompi_coll_base_free_tmpbuf(recvbuf, op_dev, module); + } return MPI_SUCCESS; @@ -533,8 +675,12 @@ ompi_coll_base_allreduce_intra_ring(const void *sbuf, void *rbuf, int count, __FILE__, line, rank, ret)); ompi_coll_base_free_reqs(reqs, 2); (void)line; // silence compiler warning - if (NULL != inbuf[0]) free(inbuf[0]); - if (NULL != inbuf[1]) free(inbuf[1]); + ompi_coll_base_free_tmpbuf(inbuf[0], op_dev, module); + if (NULL != recvbuf && recvbuf != rbuf) { + /* copy to final rbuf and release temporary recvbuf */ + 
ompi_datatype_copy_content_same_ddt(dtype, count, (char*)rbuf, (char*)recvbuf); + ompi_coll_base_free_tmpbuf(recvbuf, op_dev, module); + } return ret; } @@ -687,16 +833,21 @@ ompi_coll_base_allreduce_intra_ring_segmented(const void *sbuf, void *rbuf, int if (MPI_SUCCESS != ret) { line = __LINE__; goto error_hndl; } max_real_segsize = opal_datatype_span(&dtype->super, max_segcount, &gap); + int sendbuf_dev, recvbuf_dev, op_dev; + uint64_t sendbuf_flags, recvbuf_flags; + ompi_coll_base_select_device(op, sbuf, rbuf, count, dtype, &sendbuf_dev, &recvbuf_dev, + &sendbuf_flags, &recvbuf_flags, &op_dev); /* Allocate and initialize temporary buffers */ - inbuf[0] = (char*)malloc(max_real_segsize); + inbuf[0] = ompi_coll_base_allocate_on_device(op_dev, max_real_segsize, module); if (NULL == inbuf[0]) { ret = -1; line = __LINE__; goto error_hndl; } if (size > 2) { - inbuf[1] = (char*)malloc(max_real_segsize); + inbuf[1] = ompi_coll_base_allocate_on_device(op_dev, max_real_segsize, module); if (NULL == inbuf[1]) { ret = -1; line = __LINE__; goto error_hndl; } } /* Handle MPI_IN_PLACE */ if (MPI_IN_PLACE != sbuf) { + /* TODO: can we avoid this copy? 
*/ ret = ompi_datatype_copy_content_same_ddt(dtype, count, (char*)rbuf, (char*)sbuf); if (ret < 0) { line = __LINE__; goto error_hndl; } } @@ -782,7 +933,8 @@ ompi_coll_base_allreduce_intra_ring_segmented(const void *sbuf, void *rbuf, int ((ptrdiff_t)phase * (ptrdiff_t)early_phase_segcount) : ((ptrdiff_t)phase * (ptrdiff_t)late_phase_segcount + split_phase)); tmprecv = ((char*)rbuf) + (ptrdiff_t)(block_offset + phase_offset) * extent; - ompi_op_reduce(op, inbuf[inbi ^ 0x1], tmprecv, phase_count, dtype); + ompi_op_reduce_stream(op, inbuf[inbi ^ 0x1], tmprecv, phase_count, + dtype, op_dev, NULL); /* send previous block to send_to */ ret = MCA_PML_CALL(send(tmprecv, phase_count, dtype, send_to, @@ -811,7 +963,8 @@ ompi_coll_base_allreduce_intra_ring_segmented(const void *sbuf, void *rbuf, int ((ptrdiff_t)phase * (ptrdiff_t)early_phase_segcount) : ((ptrdiff_t)phase * (ptrdiff_t)late_phase_segcount + split_phase)); tmprecv = ((char*)rbuf) + (ptrdiff_t)(block_offset + phase_offset) * extent; - ompi_op_reduce(op, inbuf[inbi], tmprecv, phase_count, dtype); + ompi_op_reduce_stream(op, inbuf[inbi], tmprecv, phase_count, + dtype, op_dev, NULL); } /* Distribution loop - variation of ring allgather */ @@ -843,8 +996,8 @@ ompi_coll_base_allreduce_intra_ring_segmented(const void *sbuf, void *rbuf, int } - if (NULL != inbuf[0]) free(inbuf[0]); - if (NULL != inbuf[1]) free(inbuf[1]); + ompi_coll_base_free_tmpbuf(inbuf[0], op_dev, module); + ompi_coll_base_free_tmpbuf(inbuf[1], op_dev, module); return MPI_SUCCESS; @@ -853,8 +1006,8 @@ ompi_coll_base_allreduce_intra_ring_segmented(const void *sbuf, void *rbuf, int __FILE__, line, rank, ret)); ompi_coll_base_free_reqs(reqs, 2); (void)line; // silence compiler warning - if (NULL != inbuf[0]) free(inbuf[0]); - if (NULL != inbuf[1]) free(inbuf[1]); + ompi_coll_base_free_tmpbuf(inbuf[0], op_dev, module); + ompi_coll_base_free_tmpbuf(inbuf[1], op_dev, module); return ret; } @@ -1004,19 +1157,31 @@ int 
ompi_coll_base_allreduce_intra_redscat_allgather( ompi_datatype_get_extent(dtype, &lb, &extent); dsize = opal_datatype_span(&dtype->super, count, &gap); + /* get the device for sbuf and rbuf and where the op would like to execute */ + int sendbuf_dev, recvbuf_dev, op_dev; + uint64_t sendbuf_flags, recvbuf_flags; + ompi_coll_base_select_device(op, sbuf, rbuf, count, dtype, &sendbuf_dev, &recvbuf_dev, + &sendbuf_flags, &recvbuf_flags, &op_dev); + /* Temporary buffer for receiving messages */ char *tmp_buf = NULL; - char *tmp_buf_raw = (char *)malloc(dsize); + char *tmp_buf_raw = ompi_coll_base_allocate_on_device(op_dev, dsize, module); if (NULL == tmp_buf_raw) return OMPI_ERR_OUT_OF_RESOURCE; tmp_buf = tmp_buf_raw - gap; - if (sbuf != MPI_IN_PLACE) { - err = ompi_datatype_copy_content_same_ddt(dtype, count, (char *)rbuf, - (char *)sbuf); - if (MPI_SUCCESS != err) { goto cleanup_and_return; } + char *recvbuf = rbuf; + if (op_dev != recvbuf_dev && 0 == (MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY & recvbuf_flags)) { + recvbuf = ompi_coll_base_allocate_on_device(op_dev, dsize, module); + } + if (op_dev != sendbuf_dev && 0 == (MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY & sendbuf_flags) && sbuf != MPI_IN_PLACE) { + /* move the data into the recvbuf and set sbuf to MPI_IN_PLACE */ + ompi_datatype_copy_content_same_ddt(dtype, count, (char*)recvbuf, (char*)sbuf); + sbuf = MPI_IN_PLACE; } + //printf("redscat: count %d sbuf %p dev %d recvbuf %p dev %d tmp_buf %p dev %d\n", count, sbuf, sendbuf_dev, recvbuf, recvbuf_dev, tmp_buf_raw, op_dev); + /* * Step 1. Reduce the number of processes to the nearest lower power of two * p' = 2^{\floor{\log_2 p}} by removing r = p - p' processes. 
@@ -1037,9 +1202,18 @@ int ompi_coll_base_allreduce_intra_redscat_allgather( int vrank, step, wsize; int nprocs_rem = comm_size - nprocs_pof2; + opal_accelerator_stream_t *stream = NULL; + if (op_dev >= 0) { + opal_accelerator.get_default_stream(op_dev, &stream); + } + if (rank < 2 * nprocs_rem) { int count_lhalf = count / 2; int count_rhalf = count - count_lhalf; + const void *send_buf = sbuf; + if (MPI_IN_PLACE == sbuf) { + send_buf = recvbuf; + } if (rank % 2 != 0) { /* @@ -1047,7 +1221,7 @@ int ompi_coll_base_allreduce_intra_redscat_allgather( * Send the left half of the input vector to the left neighbor, * Recv the right half of the input vector from the left neighbor */ - err = ompi_coll_base_sendrecv(rbuf, count_lhalf, dtype, rank - 1, + err = ompi_coll_base_sendrecv((void*)send_buf, count_lhalf, dtype, rank - 1, MCA_COLL_BASE_TAG_ALLREDUCE, (char *)tmp_buf + (ptrdiff_t)count_lhalf * extent, count_rhalf, dtype, rank - 1, @@ -1055,12 +1229,24 @@ int ompi_coll_base_allreduce_intra_redscat_allgather( MPI_STATUS_IGNORE, rank); if (MPI_SUCCESS != err) { goto cleanup_and_return; } - /* Reduce on the right half of the buffers (result in rbuf) */ - ompi_op_reduce(op, (char *)tmp_buf + (ptrdiff_t)count_lhalf * extent, - (char *)rbuf + count_lhalf * extent, count_rhalf, dtype); + /* Reduce on the right half of the buffers (result in rbuf) + * We're not using a stream here, the reduction will make sure that the result is available upon return */ + if (MPI_IN_PLACE != sbuf) { + /* rbuf = sbuf (op) tmp_buf */ + ompi_3buff_op_reduce_stream(op, + (char *)tmp_buf + (ptrdiff_t)count_lhalf * extent, + (char *)sbuf + (ptrdiff_t)count_lhalf * extent, + (char *)recvbuf + count_lhalf * extent, + count_rhalf, dtype, op_dev, NULL); + } else { + /* rbuf = rbuf (op) tmp_buf */ + ompi_op_reduce_stream(op, (char *)tmp_buf + (ptrdiff_t)count_lhalf * extent, + (char *)recvbuf + count_lhalf * extent, count_rhalf, + dtype, op_dev, NULL); + } /* Send the right half to the left neighbor */ - 
err = MCA_PML_CALL(send((char *)rbuf + (ptrdiff_t)count_lhalf * extent, + err = MCA_PML_CALL(send((char *)recvbuf + (ptrdiff_t)count_lhalf * extent, count_rhalf, dtype, rank - 1, MCA_COLL_BASE_TAG_ALLREDUCE, MCA_PML_BASE_SEND_STANDARD, comm)); @@ -1075,7 +1261,7 @@ int ompi_coll_base_allreduce_intra_redscat_allgather( * Send the right half of the input vector to the right neighbor, * Recv the left half of the input vector from the right neighbor */ - err = ompi_coll_base_sendrecv((char *)rbuf + (ptrdiff_t)count_lhalf * extent, + err = ompi_coll_base_sendrecv((char *)send_buf + (ptrdiff_t)count_lhalf * extent, count_rhalf, dtype, rank + 1, MCA_COLL_BASE_TAG_ALLREDUCE, tmp_buf, count_lhalf, dtype, rank + 1, @@ -1084,21 +1270,35 @@ int ompi_coll_base_allreduce_intra_redscat_allgather( if (MPI_SUCCESS != err) { goto cleanup_and_return; } /* Reduce on the right half of the buffers (result in rbuf) */ - ompi_op_reduce(op, tmp_buf, rbuf, count_lhalf, dtype); + if (MPI_IN_PLACE != sbuf) { + /* rbuf = sbuf (op) tmp_buf */ + ompi_3buff_op_reduce_stream(op, sbuf, tmp_buf, recvbuf, count_lhalf, dtype, op_dev, stream); + } else { + /* rbuf = rbuf (op) tmp_buf */ + ompi_op_reduce_stream(op, tmp_buf, recvbuf, count_lhalf, dtype, op_dev, stream); + } + /* Recv the right half from the right neighbor */ - err = MCA_PML_CALL(recv((char *)rbuf + (ptrdiff_t)count_lhalf * extent, + err = MCA_PML_CALL(recv((char *)recvbuf + (ptrdiff_t)count_lhalf * extent, count_rhalf, dtype, rank + 1, MCA_COLL_BASE_TAG_ALLREDUCE, comm, MPI_STATUS_IGNORE)); if (MPI_SUCCESS != err) { goto cleanup_and_return; } + /* wait for the op to complete */ + if (NULL != stream) { + opal_accelerator.wait_stream(stream); + } + vrank = rank / 2; } } else { /* rank >= 2 * nprocs_rem */ vrank = rank - nprocs_rem; } + /* At this point the input data has been accumulated into the rbuf */ + /* * Step 2. Reduce-scatter implemented with recursive vector halving and * recursive distance doubling. 
We have p' = 2^{\floor{\log_2 p}} @@ -1155,7 +1355,7 @@ int ompi_coll_base_allreduce_intra_redscat_allgather( } /* Send part of data from the rbuf, recv into the tmp_buf */ - err = ompi_coll_base_sendrecv((char *)rbuf + (ptrdiff_t)sindex[step] * extent, + err = ompi_coll_base_sendrecv((char *)recvbuf + (ptrdiff_t)sindex[step] * extent, scount[step], dtype, dest, MCA_COLL_BASE_TAG_ALLREDUCE, (char *)tmp_buf + (ptrdiff_t)rindex[step] * extent, @@ -1165,9 +1365,9 @@ int ompi_coll_base_allreduce_intra_redscat_allgather( if (MPI_SUCCESS != err) { goto cleanup_and_return; } /* Local reduce: rbuf[] = tmp_buf[] rbuf[] */ - ompi_op_reduce(op, (char *)tmp_buf + (ptrdiff_t)rindex[step] * extent, - (char *)rbuf + (ptrdiff_t)rindex[step] * extent, - rcount[step], dtype); + ompi_op_reduce_stream(op, (char *)tmp_buf + (ptrdiff_t)rindex[step] * extent, + (char *)recvbuf + (ptrdiff_t)rindex[step] * extent, + rcount[step], dtype, op_dev, NULL); /* Move the current window to the received message */ if (step + 1 < nsteps) { @@ -1201,10 +1401,10 @@ int ompi_coll_base_allreduce_intra_redscat_allgather( * Send rcount[step] elements from rbuf[rindex[step]...] * Recv scount[step] elements to rbuf[sindex[step]...] */ - err = ompi_coll_base_sendrecv((char *)rbuf + (ptrdiff_t)rindex[step] * extent, + err = ompi_coll_base_sendrecv((char *)recvbuf + (ptrdiff_t)rindex[step] * extent, rcount[step], dtype, dest, MCA_COLL_BASE_TAG_ALLREDUCE, - (char *)rbuf + (ptrdiff_t)sindex[step] * extent, + (char *)recvbuf + (ptrdiff_t)sindex[step] * extent, scount[step], dtype, dest, MCA_COLL_BASE_TAG_ALLREDUCE, comm, MPI_STATUS_IGNORE, rank); @@ -1216,6 +1416,7 @@ int ompi_coll_base_allreduce_intra_redscat_allgather( /* * Step 4. Send total result to excluded odd ranks. 
*/ + bool recvbuf_need_copy = true; if (rank < 2 * nprocs_rem) { if (rank % 2 != 0) { /* Odd process -- recv result from rank - 1 */ @@ -1223,19 +1424,28 @@ int ompi_coll_base_allreduce_intra_redscat_allgather( MCA_COLL_BASE_TAG_ALLREDUCE, comm, MPI_STATUS_IGNORE)); if (OMPI_SUCCESS != err) { goto cleanup_and_return; } + recvbuf_need_copy = false; } else { /* Even process -- send result to rank + 1 */ - err = MCA_PML_CALL(send(rbuf, count, dtype, rank + 1, + err = MCA_PML_CALL(send(recvbuf, count, dtype, rank + 1, MCA_COLL_BASE_TAG_ALLREDUCE, MCA_PML_BASE_SEND_STANDARD, comm)); if (MPI_SUCCESS != err) { goto cleanup_and_return; } } } + if (recvbuf != rbuf) { + /* copy into final rbuf */ + if (recvbuf_need_copy) { + ompi_datatype_copy_content_same_ddt(dtype, count, (char*)rbuf, (char*)recvbuf); + } + ompi_coll_base_free_tmpbuf(recvbuf, op_dev, module); + } + cleanup_and_return: - if (NULL != tmp_buf_raw) - free(tmp_buf_raw); + + ompi_coll_base_free_tmpbuf(tmp_buf_raw, op_dev, module); if (NULL != rindex) free(rindex); if (NULL != sindex) diff --git a/ompi/mca/coll/base/coll_base_frame.c b/ompi/mca/coll/base/coll_base_frame.c index 5bb6fe38ace..45b7fcaeb67 100644 --- a/ompi/mca/coll/base/coll_base_frame.c +++ b/ompi/mca/coll/base/coll_base_frame.c @@ -30,6 +30,7 @@ #include "ompi/mca/mca.h" #include "opal/util/output.h" #include "opal/mca/base/base.h" +#include "opal/mca/accelerator/accelerator.h" #include "ompi/mca/coll/coll.h" @@ -70,6 +71,8 @@ static void coll_base_comm_construct(mca_coll_base_comm_t *data) { memset ((char *) data + sizeof (data->super), 0, sizeof (*data) - sizeof (data->super)); + data->device_allocators = NULL; + data->num_device_allocators = 0; } static void @@ -108,6 +111,16 @@ coll_base_comm_destruct(mca_coll_base_comm_t *data) if (data->cached_in_order_bintree) { /* destroy in order bintree if defined */ ompi_coll_base_topo_destroy_tree (&data->cached_in_order_bintree); } + + if (NULL != data->device_allocators) { + for (int i = 0; i < 
data->num_device_allocators; ++i) { + if (NULL != data->device_allocators[i]) { + data->device_allocators[i]->alc_finalize(data->device_allocators[i]); + } + } + free(data->device_allocators); + data->device_allocators = NULL; + } } OBJ_CLASS_INSTANCE(mca_coll_base_comm_t, opal_object_t, diff --git a/ompi/mca/coll/base/coll_base_functions.h b/ompi/mca/coll/base/coll_base_functions.h index 1c73d01d37e..b3657c2f7f4 100644 --- a/ompi/mca/coll/base/coll_base_functions.h +++ b/ompi/mca/coll/base/coll_base_functions.h @@ -40,6 +40,8 @@ /* need to include our own topo prototypes so we can malloc data on the comm correctly */ #include "coll_base_topo.h" +#include "opal/mca/allocator/allocator.h" + /* some fixed value index vars to simplify certain operations */ typedef enum COLLTYPE { ALLGATHER = 0, /* 0 */ @@ -514,6 +516,10 @@ struct mca_coll_base_comm_t { /* in-order binary tree (root of the in-order binary tree is rank 0) */ ompi_coll_tree_t *cached_in_order_bintree; + + /* pointer to per-device memory cache */ + mca_allocator_base_module_t **device_allocators; + int num_device_allocators; }; typedef struct mca_coll_base_comm_t mca_coll_base_comm_t; OMPI_DECLSPEC OBJ_CLASS_DECLARATION(mca_coll_base_comm_t); diff --git a/ompi/mca/coll/base/coll_base_reduce.c b/ompi/mca/coll/base/coll_base_reduce.c index e7cf7e0656a..72efde701f7 100644 --- a/ompi/mca/coll/base/coll_base_reduce.c +++ b/ompi/mca/coll/base/coll_base_reduce.c @@ -114,7 +114,7 @@ int ompi_coll_base_reduce_generic( const void* sendbuf, void* recvbuf, int origi /* If this is a non-commutative operation we must copy sendbuf to the accumbuf, in order to simplify the loops */ - + if (!ompi_op_is_commute(op) && MPI_IN_PLACE != sendbuf) { ompi_datatype_copy_content_same_ddt(datatype, original_count, (char*)accumbuf, diff --git a/ompi/mca/coll/base/coll_base_util.c b/ompi/mca/coll/base/coll_base_util.c index 9465ef95c62..9479c68f7cc 100644 --- a/ompi/mca/coll/base/coll_base_util.c +++ 
b/ompi/mca/coll/base/coll_base_util.c @@ -30,6 +30,7 @@ #include "ompi/mca/pml/pml.h" #include "coll_base_util.h" #include "coll_base_functions.h" +#include "opal/mca/allocator/base/base.h" #include int ompi_coll_base_sendrecv_actual( const void* sendbuf, size_t scount, @@ -602,3 +603,58 @@ const char* mca_coll_base_colltype_to_str(int collid) } return colltype_translation_table[collid]; } + +static void* ompi_coll_base_device_allocate_cb(void *ctx, size_t *size) { + int dev_id = (intptr_t)ctx; + void *ptr = NULL; + opal_accelerator.mem_alloc(dev_id, &ptr, *size); + return ptr; +} + +static void ompi_coll_base_device_release_cb(void *ctx, void* ptr) { + int dev_id = (intptr_t)ctx; + opal_accelerator.mem_release(dev_id, ptr); +} + +void *ompi_coll_base_allocate_on_device(int device, size_t size, + mca_coll_base_module_t *module) +{ + mca_allocator_base_module_t *allocator_module; + if (device < 0) { + return malloc(size); + } + + if (module->base_data->num_device_allocators <= device) { + int num_dev; + opal_accelerator.num_devices(&num_dev); + if (num_dev < device+1) num_dev = device+1; + module->base_data->device_allocators = realloc(module->base_data->device_allocators, num_dev * sizeof(mca_allocator_base_module_t *)); + for (int i = module->base_data->num_device_allocators; i < num_dev; ++i) { + module->base_data->device_allocators[i] = NULL; + } + module->base_data->num_device_allocators = num_dev; + } + if (NULL == (allocator_module = module->base_data->device_allocators[device])) { + mca_allocator_base_component_t *allocator_component; + allocator_component = mca_allocator_component_lookup("devicebucket"); + assert(allocator_component != NULL); + allocator_module = allocator_component->allocator_init(false, ompi_coll_base_device_allocate_cb, + ompi_coll_base_device_release_cb, + (void*)(intptr_t)device); + assert(allocator_module != NULL); + module->base_data->device_allocators[device] = allocator_module; + } + return 
allocator_module->alc_alloc(allocator_module, size, 0); +} + +void ompi_coll_base_free_on_device(int device, void *ptr, mca_coll_base_module_t *module) +{ + mca_allocator_base_module_t *allocator_module; + if (device < 0) { + free(ptr); + } else { + assert(NULL != module->base_data->device_allocators); + allocator_module = module->base_data->device_allocators[device]; + allocator_module->alc_free(allocator_module, ptr); + } +} diff --git a/ompi/mca/coll/base/coll_base_util.h b/ompi/mca/coll/base/coll_base_util.h index 6982c0fb4f3..9dd1a6f41d0 100644 --- a/ompi/mca/coll/base/coll_base_util.h +++ b/ompi/mca/coll/base/coll_base_util.h @@ -31,6 +31,7 @@ #include "ompi/mca/coll/base/coll_tags.h" #include "ompi/op/op.h" #include "ompi/mca/pml/pml.h" +#include "opal/mca/accelerator/accelerator.h" BEGIN_C_DECLS @@ -203,5 +204,48 @@ int ompi_coll_base_file_peek_next_char_is(FILE *fptr, int *fileline, int expecte const char* mca_coll_base_colltype_to_str(int collid); int mca_coll_base_name_to_colltype(const char* name); +/* device/host memory allocation functions */ + + +void *ompi_coll_base_allocate_on_device(int device, size_t size, + mca_coll_base_module_t *module); + +void ompi_coll_base_free_on_device(int device, void *ptr, mca_coll_base_module_t *module); + + +static inline +void ompi_coll_base_select_device( + struct ompi_op_t *op, + const void *sendbuf, + const void *recvbuf, + size_t count, + struct ompi_datatype_t *dtype, + int *sendbuf_device, + int *recvbuf_device, + uint64_t *sendbuf_flags, + uint64_t *recvbuf_flags, + int *op_device) +{ + *recvbuf_device = MCA_ACCELERATOR_NO_DEVICE_ID; + *sendbuf_device = MCA_ACCELERATOR_NO_DEVICE_ID; + if (sendbuf != NULL && sendbuf != MPI_IN_PLACE) opal_accelerator.check_addr(sendbuf, sendbuf_device, sendbuf_flags); + if (recvbuf != NULL) opal_accelerator.check_addr(recvbuf, recvbuf_device, recvbuf_flags); + ompi_op_preferred_device(op, *recvbuf_device, *sendbuf_device, count, dtype, op_device); +} + +/** + * Frees memory 
allocated through ompi_coll_base_allocate_op_tmpbuf + * or ompi_coll_base_allocate_tmpbuf. + */ +static inline +void ompi_coll_base_free_tmpbuf(void *tmpbuf, int device, mca_coll_base_module_t *module) { + if (-1 == device) { + free(tmpbuf); + } else if (NULL != tmpbuf) { + ompi_coll_base_free_on_device(device, tmpbuf, module); + } +} + + END_C_DECLS #endif /* MCA_COLL_BASE_UTIL_EXPORT_H */ diff --git a/ompi/mca/coll/basic/coll_basic_allreduce.c b/ompi/mca/coll/basic/coll_basic_allreduce.c index bc855726208..065d358a4af 100644 --- a/ompi/mca/coll/basic/coll_basic_allreduce.c +++ b/ompi/mca/coll/basic/coll_basic_allreduce.c @@ -2,7 +2,7 @@ * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana * University Research and Technology * Corporation. All rights reserved. - * Copyright (c) 2004-2017 The University of Tennessee and The University + * Copyright (c) 2004-2023 The University of Tennessee and The University * of Tennessee Research Foundation. All rights * reserved. 
* Copyright (c) 2004-2005 High Performance Computing Center Stuttgart, @@ -32,6 +32,8 @@ #include "coll_basic.h" #include "ompi/mca/pml/pml.h" +#include "opal/mca/accelerator/accelerator.h" + /* * allreduce_intra @@ -82,10 +84,11 @@ mca_coll_basic_allreduce_inter(const void *sbuf, void *rbuf, int count, struct ompi_communicator_t *comm, mca_coll_base_module_t *module) { - int err, i, rank, root = 0, rsize, line; + int err, i, rank, root = 0, rsize, line, rbuf_dev; ptrdiff_t extent, dsize, gap; char *tmpbuf = NULL, *pml_buffer = NULL; ompi_request_t **reqs = NULL; + bool rbuf_on_device = false; rank = ompi_comm_rank(comm); rsize = ompi_comm_remote_size(comm); @@ -105,8 +108,15 @@ mca_coll_basic_allreduce_inter(const void *sbuf, void *rbuf, int count, return OMPI_ERROR; } dsize = opal_datatype_span(&dtype->super, count, &gap); - tmpbuf = (char *) malloc(dsize); - if (NULL == tmpbuf) { err = OMPI_ERR_OUT_OF_RESOURCE; line = __LINE__; goto exit; } + if (opal_accelerator.check_addr(rbuf, &rbuf_dev, NULL) > 0 && rbuf_dev >= 0) { + if (OPAL_SUCCESS != opal_accelerator.mem_alloc(rbuf_dev, (void**)&tmpbuf, dsize)) { + err = OMPI_ERR_OUT_OF_RESOURCE; line = __LINE__; goto exit; + } + rbuf_on_device = true; + } else { + tmpbuf = (char *) malloc(dsize); + if (NULL == tmpbuf) { err = OMPI_ERR_OUT_OF_RESOURCE; line = __LINE__; goto exit; } + } pml_buffer = tmpbuf - gap; if (rsize > 1) { @@ -188,7 +198,9 @@ mca_coll_basic_allreduce_inter(const void *sbuf, void *rbuf, int count, (void)line; // silence compiler warning ompi_coll_base_free_reqs(reqs, rsize - 1); } - if (NULL != tmpbuf) { + if (rbuf_on_device) { + opal_accelerator.mem_release(rbuf_dev, tmpbuf); + } else { free(tmpbuf); } diff --git a/ompi/mca/op/base/op_base_frame.c b/ompi/mca/op/base/op_base_frame.c index 90167300851..1a7d6dc1320 100644 --- a/ompi/mca/op/base/op_base_frame.c +++ b/ompi/mca/op/base/op_base_frame.c @@ -2,7 +2,7 @@ * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana * University 
Research and Technology * Corporation. All rights reserved. - * Copyright (c) 2004-2005 The University of Tennessee and The University + * Copyright (c) 2004-2023 The University of Tennessee and The University * of Tennessee Research Foundation. All rights * reserved. * Copyright (c) 2004-2005 High Performance Computing Center Stuttgart, @@ -42,6 +42,7 @@ static void module_constructor(ompi_op_base_module_t *m) { m->opm_enable = NULL; m->opm_op = NULL; + m->opm_device_enabled = false; memset(&(m->opm_fns), 0, sizeof(m->opm_fns)); memset(&(m->opm_3buff_fns), 0, sizeof(m->opm_3buff_fns)); } @@ -50,6 +51,7 @@ static void module_constructor_1_0_0(ompi_op_base_module_1_0_0_t *m) { m->opm_enable = NULL; m->opm_op = NULL; + m->opm_device_enabled = false; memset(&(m->opm_fns), 0, sizeof(m->opm_fns)); memset(&(m->opm_3buff_fns), 0, sizeof(m->opm_3buff_fns)); } diff --git a/ompi/mca/op/base/op_base_op_select.c b/ompi/mca/op/base/op_base_op_select.c index 53754ce5668..534a1d63267 100644 --- a/ompi/mca/op/base/op_base_op_select.c +++ b/ompi/mca/op/base/op_base_op_select.c @@ -3,7 +3,7 @@ * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana * University Research and Technology * Corporation. All rights reserved. - * Copyright (c) 2004-2009 The University of Tennessee and The University + * Copyright (c) 2004-2023 The University of Tennessee and The University * of Tennessee Research Foundation. All rights * reserved. 
* Copyright (c) 2004-2005 High Performance Computing Center Stuttgart, @@ -152,22 +152,50 @@ int ompi_op_base_op_select(ompi_op_t *op) } /* Copy over the non-NULL pointers */ - for (i = 0; i < OMPI_OP_BASE_TYPE_MAX; ++i) { - /* 2-buffer variants */ - if (NULL != avail->ao_module->opm_fns[i]) { - OBJ_RELEASE(op->o_func.intrinsic.modules[i]); - op->o_func.intrinsic.fns[i] = avail->ao_module->opm_fns[i]; - op->o_func.intrinsic.modules[i] = avail->ao_module; - OBJ_RETAIN(avail->ao_module); + if (avail->ao_module->opm_device_enabled) { + if (NULL == op->o_device_op) { + op->o_device_op = calloc(1, sizeof(*op->o_device_op)); } - - /* 3-buffer variants */ - if (NULL != avail->ao_module->opm_3buff_fns[i]) { - OBJ_RELEASE(op->o_func.intrinsic.modules[i]); - op->o_3buff_intrinsic.fns[i] = - avail->ao_module->opm_3buff_fns[i]; - op->o_3buff_intrinsic.modules[i] = avail->ao_module; - OBJ_RETAIN(avail->ao_module); + for (i = 0; i < OMPI_OP_BASE_TYPE_MAX; ++i) { + /* 2-buffer variants */ + if (NULL != avail->ao_module->opm_stream_fns[i]) { + if (NULL != op->o_device_op->do_intrinsic.modules[i]) { + OBJ_RELEASE(op->o_device_op->do_intrinsic.modules[i]); + } + op->o_device_op->do_intrinsic.fns[i] = avail->ao_module->opm_stream_fns[i]; + op->o_device_op->do_intrinsic.modules[i] = avail->ao_module; + OBJ_RETAIN(avail->ao_module); + } + + /* 3-buffer variants */ + if (NULL != avail->ao_module->opm_3buff_stream_fns[i]) { + if (NULL != op->o_device_op->do_3buff_intrinsic.modules[i]) { + OBJ_RELEASE(op->o_device_op->do_3buff_intrinsic.modules[i]); + } + op->o_device_op->do_3buff_intrinsic.fns[i] = + avail->ao_module->opm_3buff_stream_fns[i]; + op->o_device_op->do_3buff_intrinsic.modules[i] = avail->ao_module; + OBJ_RETAIN(avail->ao_module); + } + } + } else { + for (i = 0; i < OMPI_OP_BASE_TYPE_MAX; ++i) { + /* 2-buffer variants */ + if (NULL != avail->ao_module->opm_fns[i]) { + OBJ_RELEASE(op->o_func.intrinsic.modules[i]); + op->o_func.intrinsic.fns[i] = avail->ao_module->opm_fns[i]; + 
op->o_func.intrinsic.modules[i] = avail->ao_module; + OBJ_RETAIN(avail->ao_module); + } + + /* 3-buffer variants */ + if (NULL != avail->ao_module->opm_3buff_fns[i]) { + OBJ_RELEASE(op->o_3buff_intrinsic.modules[i]); + op->o_3buff_intrinsic.fns[i] = + avail->ao_module->opm_3buff_fns[i]; + op->o_3buff_intrinsic.modules[i] = avail->ao_module; + OBJ_RETAIN(avail->ao_module); + } } } diff --git a/ompi/mca/op/cuda/Makefile.am b/ompi/mca/op/cuda/Makefile.am new file mode 100644 index 00000000000..5e68ddf5854 --- /dev/null +++ b/ompi/mca/op/cuda/Makefile.am @@ -0,0 +1,84 @@ +# +# Copyright (c) 2023 The University of Tennessee and The University +# of Tennessee Research Foundation. All rights +# reserved. +# $COPYRIGHT$ +# +# Additional copyrights may follow +# +# $HEADER$ +# + +# This component provides support for offloading reduce ops to CUDA devices. +# +# See https://github.com/open-mpi/ompi/wiki/devel-CreateComponent +# for more details on how to make Open MPI components. + +# First, list all .h and .c sources. It is necessary to list all .h +# files so that they will be picked up in the distribution tarball. + +AM_CPPFLAGS = $(op_cuda_CPPFLAGS) $(op_cudart_CPPFLAGS) + +dist_ompidata_DATA = help-ompi-mca-op-cuda.txt + +sources = op_cuda_component.c op_cuda.h op_cuda_functions.c op_cuda_impl.h +#sources_extended = op_cuda_functions.cu +cu_sources = op_cuda_impl.cu + +NVCC = nvcc -g +NVCCFLAGS= --std c++17 --gpu-architecture=compute_52 + +.cu.l$(OBJEXT): + $(LIBTOOL) $(AM_V_lt) --tag=CC $(AM_LIBTOOLFLAGS) \ + $(LIBTOOLFLAGS) --mode=compile $(NVCC) -prefer-non-pic $(NVCCFLAGS) -Wc,-Xcompiler,-fPIC,-g -c $< + +# -o $($@.o:.lo) + +# Open MPI components can be compiled two ways: +# +# 1. As a standalone dynamic shared object (DSO), sometimes called a +# dynamically loadable library (DLL). +# +# 2. As a static library that is slurped up into the upper-level +# libmpi library (regardless of whether libmpi is a static or dynamic +# library). 
This is called a "Libtool convenience library". +# +# The component needs to create an output library in this top-level +# component directory, and named either mca__.la (for DSO +# builds) or libmca__.la (for static builds). The OMPI +# build system will have set the +# MCA_BUILD_ompi___DSO AM_CONDITIONAL to indicate +# which way this component should be built. + +if MCA_BUILD_ompi_op_cuda_DSO +component_install = mca_op_cuda.la +else +component_install = +component_noinst = libmca_op_cuda.la +endif + +# Specific information for DSO builds. +# +# The DSO should install itself in $(ompilibdir) (by default, +# $prefix/lib/openmpi). + +mcacomponentdir = $(ompilibdir) +mcacomponent_LTLIBRARIES = $(component_install) +mca_op_cuda_la_SOURCES = $(sources) +mca_op_cuda_la_LIBADD = $(cu_sources:.cu=.lo) +mca_op_cuda_la_LDFLAGS = -module -avoid-version $(top_builddir)/ompi/lib@OMPI_LIBMPI_NAME@.la \ + $(op_cuda_LIBS) $(op_cudart_LIBS) +EXTRA_mca_op_cuda_la_SOURCES = $(cu_sources) + +# Specific information for static builds. +# +# Note that we *must* "noinst"; the upper-layer Makefile.am's will +# slurp in the resulting .la library into libmpi. + +noinst_LTLIBRARIES = $(component_noinst) +libmca_op_cuda_la_SOURCES = $(sources) +libmca_op_cuda_la_LIBADD = $(cu_sources:.cu=.lo) +libmca_op_cuda_la_LDFLAGS = -module -avoid-version\ + $(op_cuda_LIBS) $(op_cudart_LIBS) +EXTRA_libmca_op_cuda_la_SOURCES = $(cu_sources) + diff --git a/ompi/mca/op/cuda/configure.m4 b/ompi/mca/op/cuda/configure.m4 new file mode 100644 index 00000000000..0974e3aaf31 --- /dev/null +++ b/ompi/mca/op/cuda/configure.m4 @@ -0,0 +1,41 @@ +# -*- shell-script -*- +# +# Copyright (c) 2011-2013 NVIDIA Corporation. All rights reserved. +# Copyright (c) 2023 The University of Tennessee and The University +# of Tennessee Research Foundation. All rights +# reserved. +# Copyright (c) 2022 Amazon.com, Inc. or its affiliates. +# All Rights reserved. 
+# $COPYRIGHT$ +# +# Additional copyrights may follow +# +# $HEADER$ +# + +# +# If CUDA support was requested, then build the CUDA support library. +# This code checks makes sure the check was done earlier by the +# opal_check_cuda.m4 code. It also copies the flags and libs under +# opal_cuda_CPPFLAGS, opal_cuda_LDFLAGS, and opal_cuda_LIBS + +AC_DEFUN([MCA_ompi_op_cuda_CONFIG],[ + + AC_CONFIG_FILES([ompi/mca/op/cuda/Makefile]) + + OPAL_CHECK_CUDA([op_cuda]) + OPAL_CHECK_CUDART([op_cudart]) + + AS_IF([test "x$CUDA_SUPPORT" = "x1"], + [$1], + [$2]) + + AC_SUBST([op_cuda_CPPFLAGS]) + AC_SUBST([op_cuda_LDFLAGS]) + AC_SUBST([op_cuda_LIBS]) + + AC_SUBST([op_cudart_CPPFLAGS]) + AC_SUBST([op_cudart_LDFLAGS]) + AC_SUBST([op_cudart_LIBS]) + +])dnl diff --git a/ompi/mca/op/cuda/help-ompi-mca-op-cuda.txt b/ompi/mca/op/cuda/help-ompi-mca-op-cuda.txt new file mode 100644 index 00000000000..f999ebc939c --- /dev/null +++ b/ompi/mca/op/cuda/help-ompi-mca-op-cuda.txt @@ -0,0 +1,15 @@ +# -*- text -*- +# +# Copyright (c) 2023 The University of Tennessee and The University +# of Tennessee Research Foundation. All rights +# reserved. +# $COPYRIGHT$ +# +# Additional copyrights may follow +# +# $HEADER$ +# +# This is the US/English help file for Open MPI's CUDA operator component +# +[CUDA call failed] +"CUDA call %s failed: %s: %s\n" diff --git a/ompi/mca/op/cuda/op_cuda.h b/ompi/mca/op/cuda/op_cuda.h new file mode 100644 index 00000000000..ab349d48ee4 --- /dev/null +++ b/ompi/mca/op/cuda/op_cuda.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2019-2023 The University of Tennessee and The University + * of Tennessee Research Foundation. All rights + * reserved. 
+ * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#ifndef MCA_OP_CUDA_EXPORT_H +#define MCA_OP_CUDA_EXPORT_H + +#include "ompi_config.h" + +#include "ompi/mca/mca.h" +#include "opal/class/opal_object.h" + +#include "ompi/mca/op/op.h" +#include "ompi/runtime/mpiruntime.h" + +#include +#include + +BEGIN_C_DECLS + + +#define xstr(x) #x +#define str(x) xstr(x) + +#define CHECK(fn, args) \ + do { \ + cudaError_t err = fn args; \ + if (err != cudaSuccess) { \ + opal_show_help("help-ompi-mca-op-cuda.txt", \ + "CUDA call failed", true, \ + str(fn), cudaGetErrorName(err), \ + cudaGetErrorString(err)); \ + ompi_mpi_abort(MPI_COMM_WORLD, 1); \ + } \ + } while (0) + + +/** + * Derive a struct from the base op component struct, allowing us to + * cache some component-specific information on our well-known + * component struct. + */ +typedef struct { + /** The base op component struct */ + ompi_op_base_component_1_0_0_t super; + int cu_max_num_blocks; + int cu_max_num_threads; + int *cu_max_threads_per_block; + int *cu_max_blocks; + CUdevice *cu_devices; + int cu_num_devices; +} ompi_op_cuda_component_t; + +/** + * Globally exported variable. Note that it is a *cuda* component + * (defined above), which has the ompi_op_base_component_t as its + * first member. Hence, the MCA/op framework will find the data that + * it expects in the first memory locations, but then the component + * itself can cache additional information after that that can be used + * by both the component and modules. 
+ */ +OMPI_DECLSPEC extern ompi_op_cuda_component_t + mca_op_cuda_component; + +OMPI_DECLSPEC extern +ompi_op_base_stream_handler_fn_t ompi_op_cuda_functions[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX]; + +OMPI_DECLSPEC extern +ompi_op_base_3buff_stream_handler_fn_t ompi_op_cuda_3buff_functions[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX]; + +END_C_DECLS + +#endif /* MCA_OP_CUDA_EXPORT_H */ diff --git a/ompi/mca/op/cuda/op_cuda_component.c b/ompi/mca/op/cuda/op_cuda_component.c new file mode 100644 index 00000000000..5070e8a4c94 --- /dev/null +++ b/ompi/mca/op/cuda/op_cuda_component.c @@ -0,0 +1,196 @@ +/* + * Copyright (c) 2019-2023 The University of Tennessee and The University + * of Tennessee Research Foundation. All rights + * reserved. + * Copyright (c) 2020 Research Organization for Information Science + * and Technology (RIST). All rights reserved. + * Copyright (c) 2021 Cisco Systems, Inc. All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +/** @file + * + * This is the "cuda" op component source code. 
+ * + */ + +#include "ompi_config.h" + +#include "opal/util/printf.h" + +#include "ompi/constants.h" +#include "ompi/op/op.h" +#include "ompi/mca/op/op.h" +#include "ompi/mca/op/base/base.h" +#include "ompi/mca/op/cuda/op_cuda.h" + +#include + +static int cuda_component_open(void); +static int cuda_component_close(void); +static int cuda_component_init_query(bool enable_progress_threads, + bool enable_mpi_thread_multiple); +static struct ompi_op_base_module_1_0_0_t * + cuda_component_op_query(struct ompi_op_t *op, int *priority); +static int cuda_component_register(void); + +ompi_op_cuda_component_t mca_op_cuda_component = { + { + .opc_version = { + OMPI_OP_BASE_VERSION_1_0_0, + + .mca_component_name = "cuda", + MCA_BASE_MAKE_VERSION(component, OMPI_MAJOR_VERSION, OMPI_MINOR_VERSION, + OMPI_RELEASE_VERSION), + .mca_open_component = cuda_component_open, + .mca_close_component = cuda_component_close, + .mca_register_component_params = cuda_component_register, + }, + .opc_data = { + /* The component is checkpoint ready */ + MCA_BASE_METADATA_PARAM_CHECKPOINT + }, + + .opc_init_query = cuda_component_init_query, + .opc_op_query = cuda_component_op_query, + }, + .cu_max_num_blocks = -1, + .cu_max_num_threads = -1, + .cu_max_threads_per_block = NULL, + .cu_max_blocks = NULL, + .cu_devices = NULL, + .cu_num_devices = 0, +}; + +/* + * Component open + */ +static int cuda_component_open(void) +{ + return OMPI_SUCCESS; +} + +/* + * Component close + */ +static int cuda_component_close(void) +{ + if (mca_op_cuda_component.cu_num_devices > 0) { + free(mca_op_cuda_component.cu_max_threads_per_block); + mca_op_cuda_component.cu_max_threads_per_block = NULL; + free(mca_op_cuda_component.cu_max_blocks); + mca_op_cuda_component.cu_max_blocks = NULL; + free(mca_op_cuda_component.cu_devices); + mca_op_cuda_component.cu_devices = NULL; + mca_op_cuda_component.cu_num_devices = 0; + } + + return OMPI_SUCCESS; +} + +/* + * Register MCA params. 
+ */ +static int +cuda_component_register(void) +{ + mca_base_var_enum_flag_t *new_enum_flag = NULL; + (void) mca_base_component_var_register(&mca_op_cuda_component.super.opc_version, + "max_num_blocks", + "Maximum number of thread blocks in kernels (-1: device limit)", + MCA_BASE_VAR_TYPE_INT, + &(new_enum_flag->super), 0, 0, + OPAL_INFO_LVL_4, + MCA_BASE_VAR_SCOPE_LOCAL, + &mca_op_cuda_component.cu_max_num_blocks); + + (void) mca_base_component_var_register(&mca_op_cuda_component.super.opc_version, + "max_num_threads", + "Maximum number of threads per block in kernels (-1: device limit)", + MCA_BASE_VAR_TYPE_INT, + &(new_enum_flag->super), 0, 0, + OPAL_INFO_LVL_4, + MCA_BASE_VAR_SCOPE_LOCAL, + &mca_op_cuda_component.cu_max_num_threads); + + return OMPI_SUCCESS; +} + + +/* + * Query whether this component wants to be used in this process. + */ +static int +cuda_component_init_query(bool enable_progress_threads, + bool enable_mpi_thread_multiple) +{ + int num_devices; + int rc; + int prio_lo, prio_hi; + // TODO: is this init needed here? 
+ cuInit(0); + CHECK(cuDeviceGetCount, (&num_devices)); + mca_op_cuda_component.cu_num_devices = num_devices; + mca_op_cuda_component.cu_devices = (CUdevice*)malloc(num_devices*sizeof(CUdevice)); + mca_op_cuda_component.cu_max_threads_per_block = (int*)malloc(num_devices*sizeof(int)); + mca_op_cuda_component.cu_max_blocks = (int*)malloc(num_devices*sizeof(int)); + for (int i = 0; i < num_devices; ++i) { + CHECK(cuDeviceGet, (&mca_op_cuda_component.cu_devices[i], i)); + rc = cuDeviceGetAttribute(&mca_op_cuda_component.cu_max_threads_per_block[i], + CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X, + mca_op_cuda_component.cu_devices[i]); + if (CUDA_SUCCESS != rc) { + /* fall-back to value that should work on every device */ + mca_op_cuda_component.cu_max_threads_per_block[i] = 512; + } + if (-1 < mca_op_cuda_component.cu_max_num_threads) { + if (mca_op_cuda_component.cu_max_threads_per_block[i] >= mca_op_cuda_component.cu_max_num_threads) { + mca_op_cuda_component.cu_max_threads_per_block[i] = mca_op_cuda_component.cu_max_num_threads; + } + } + + rc = cuDeviceGetAttribute(&mca_op_cuda_component.cu_max_blocks[i], + CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X, + mca_op_cuda_component.cu_devices[i]); + if (CUDA_SUCCESS != rc) { + /* fall-back to value that should work on every device */ + mca_op_cuda_component.cu_max_blocks[i] = 512; + } + if (-1 < mca_op_cuda_component.cu_max_num_blocks) { + if (mca_op_cuda_component.cu_max_blocks[i] >= mca_op_cuda_component.cu_max_num_blocks) { + mca_op_cuda_component.cu_max_blocks[i] = mca_op_cuda_component.cu_max_num_blocks; + } + } + } + + return OMPI_SUCCESS; +} + +/* + * Query whether this component can be used for a specific op + */ +static struct ompi_op_base_module_1_0_0_t* +cuda_component_op_query(struct ompi_op_t *op, int *priority) +{ + ompi_op_base_module_t *module = NULL; + + module = OBJ_NEW(ompi_op_base_module_t); + module->opm_device_enabled = true; + for (int i = 0; i < OMPI_OP_BASE_TYPE_MAX; ++i) { + module->opm_stream_fns[i] = 
ompi_op_cuda_functions[op->o_f_to_c_index][i];
+        module->opm_3buff_stream_fns[i] = ompi_op_cuda_3buff_functions[op->o_f_to_c_index][i];
+
+        /* retain the module once per provided function; test the stream
+         * function tables that were actually populated above -- the original
+         * tested opm_fns/opm_3buff_fns, which this loop never sets, so the
+         * retains never happened */
+        if( NULL != module->opm_stream_fns[i] ) {
+            OBJ_RETAIN(module);
+        }
+        if( NULL != module->opm_3buff_stream_fns[i] ) {
+            OBJ_RETAIN(module);
+        }
+    }
+    *priority = 50;
+    return (ompi_op_base_module_1_0_0_t *) module;
+}
diff --git a/ompi/mca/op/cuda/op_cuda_functions.c b/ompi/mca/op/cuda/op_cuda_functions.c
new file mode 100644
index 00000000000..26e54cb0851
--- /dev/null
+++ b/ompi/mca/op/cuda/op_cuda_functions.c
@@ -0,0 +1,1744 @@
+/*
+ * Copyright (c) 2019-2023 The University of Tennessee and The University
+ *                         of Tennessee Research Foundation.  All rights
+ *                         reserved.
+ * Copyright (c) 2020      Research Organization for Information Science
+ *                         and Technology (RIST).  All rights reserved.
+ * $COPYRIGHT$
+ *
+ * Additional copyrights may follow
+ *
+ * $HEADER$
+ */
+
+#include "ompi_config.h"
+
+#ifdef HAVE_SYS_TYPES_H
+#include <sys/types.h>
+#endif
+#include "opal/util/output.h"
+
+
+#include "ompi/op/op.h"
+#include "ompi/mca/op/op.h"
+#include "ompi/mca/op/base/base.h"
+#include "ompi/mca/op/cuda/op_cuda.h"
+#include "opal/mca/accelerator/accelerator.h"
+
+#include "ompi/mca/op/cuda/op_cuda.h"
+#include "ompi/mca/op/cuda/op_cuda_impl.h"
+
+
+/* Stage input/output buffers onto a single device before launching a
+ * reduction kernel: detects where each buffer lives, picks a target device,
+ * and allocates/copies host- or foreign-device buffers onto it. */
+static inline void device_op_pre(const void *orig_source1,
+                                 void **source1,
+                                 int *source1_device,
+                                 const void *orig_source2,
+                                 void **source2,
+                                 int *source2_device,
+                                 void *orig_target,
+                                 void **target,
+                                 int *target_device,
+                                 int count,
+                                 struct ompi_datatype_t *dtype,
+                                 int *threads_per_block,
+                                 int *max_blocks,
+                                 int *device,
+                                 opal_accelerator_stream_t *stream)
+{
+    uint64_t target_flags = -1, source1_flags = -1, source2_flags = -1;
+    /* NOTE(review): target_rc/source1_rc are assigned only on the
+     * check_addr path below -- confirm every read is guarded */
+    int target_rc, source1_rc, source2_rc = -1;
+
+    *target = orig_target;
+    *source1 = (void*)orig_source1;
+    if (NULL != orig_source2) {
+        *source2 = (void*)orig_source2;
+    }
+
+    if (*device != MCA_ACCELERATOR_NO_DEVICE_ID) {
+        /* we got the device from the caller, just adjust the
output parameters */ + *target_device = *device; + *source1_device = *device; + if (NULL != source2_device) { + *source2_device = *device; + } + } else { + + target_rc = opal_accelerator.check_addr(*target, target_device, &target_flags); + source1_rc = opal_accelerator.check_addr(*source1, source1_device, &source1_flags); + *device = *target_device; + + if (NULL != orig_source2) { + source2_rc = opal_accelerator.check_addr(*source2, source2_device, &source2_flags); + } + + if (0 == target_rc && 0 == source1_rc && 0 == source2_rc) { + /* no buffers are on any device, select device 0 */ + *device = 0; + } else if (*target_device == -1) { + if (*source1_device == -1 && NULL != orig_source2) { + *device = *source2_device; + } else { + *device = *source1_device; + } + } + + if (0 == target_rc || 0 == source1_rc || *target_device != *source1_device) { + size_t nbytes; + ompi_datatype_type_size(dtype, &nbytes); + nbytes *= count; + + if (0 == target_rc) { + // allocate memory on the device for the target buffer + opal_accelerator.mem_alloc_stream(*device, target, nbytes, stream); + CHECK(cuMemcpyHtoDAsync, ((CUdeviceptr)*target, orig_target, nbytes, *(CUstream*)stream->stream)); + *target_device = -1; // mark target device as host + } + + if (0 == source1_rc || *device != *source1_device) { + // allocate memory on the device for the source buffer + opal_accelerator.mem_alloc_stream(*device, source1, nbytes, stream); + if (0 == source1_rc) { + /* copy from host to device */ + CHECK(cuMemcpyHtoDAsync, ((CUdeviceptr)*source1, orig_source1, nbytes, *(CUstream*)stream->stream)); + } else { + /* copy from one device to another device */ + /* TODO: does this actually work? Can we enable P2P? 
*/ + CHECK(cuMemcpyDtoDAsync, ((CUdeviceptr)*source1, (CUdeviceptr)orig_source1, nbytes, *(CUstream*)stream->stream)); + } + } + + } + if (NULL != source2_device && *target_device != *source2_device) { + // allocate memory on the device for the source buffer + size_t nbytes; + ompi_datatype_type_size(dtype, &nbytes); + nbytes *= count; + + opal_accelerator.mem_alloc_stream(*device, source2, nbytes, stream); + if (0 == source2_rc) { + /* copy from host to device */ + //printf("copying source from host to device %d\n", *device); + CHECK(cuMemcpyHtoDAsync, ((CUdeviceptr)*source2, orig_source2, nbytes, *(CUstream*)stream->stream)); + } else { + /* copy from one device to another device */ + /* TODO: does this actually work? Can we enable P2P? */ + CHECK(cuMemcpyDtoDAsync, ((CUdeviceptr)*source2, (CUdeviceptr)orig_source2, nbytes, *(CUstream*)stream->stream)); + } + } + } + *threads_per_block = mca_op_cuda_component.cu_max_threads_per_block[*device]; + *max_blocks = mca_op_cuda_component.cu_max_blocks[*device]; +} + +static inline void device_op_post(void *source1, + int source1_device, + void *source2, + int source2_device, + void *orig_target, + void *target, + int target_device, + int count, + struct ompi_datatype_t *dtype, + int device, + opal_accelerator_stream_t *stream) +{ + if (MCA_ACCELERATOR_NO_DEVICE_ID == target_device) { + + size_t nbytes; + ompi_datatype_type_size(dtype, &nbytes); + nbytes *= count; + + CHECK(cuMemcpyDtoHAsync, (orig_target, (CUdeviceptr)target, nbytes, *(CUstream *)stream->stream)); + } + + if (MCA_ACCELERATOR_NO_DEVICE_ID == target_device) { + opal_accelerator.mem_release_stream(device, target, stream); + //CHECK(cuMemFreeAsync, ((CUdeviceptr)target, mca_op_cuda_component.cu_stream)); + } + if (source1_device != device) { + opal_accelerator.mem_release_stream(device, source1, stream); + //CHECK(cuMemFreeAsync, ((CUdeviceptr)source, mca_op_cuda_component.cu_stream)); + } + if (NULL != source2 && source2_device != device) { + 
opal_accelerator.mem_release_stream(device, source2, stream); + //CHECK(cuMemFreeAsync, ((CUdeviceptr)source, mca_op_cuda_component.cu_stream)); + } +} + +#define FUNC(name, type_name, type) \ + static \ + void ompi_op_cuda_2buff_##name##_##type_name(const void *in, void *inout, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + int threads_per_block, max_blocks; \ + int source_device, target_device; \ + type *source, *target; \ + int n = *count; \ + device_op_pre(in, (void**)&source, &source_device, NULL, NULL, NULL, \ + inout, (void**)&target, &target_device, \ + n, *dtype, \ + &threads_per_block, &max_blocks, &device, stream); \ + CUstream *custream = (CUstream*)stream->stream; \ + ompi_op_cuda_2buff_##name##_##type_name##_submit(source, target, n, threads_per_block, max_blocks, *custream); \ + device_op_post(source, source_device, NULL, -1, inout, target, target_device, n, *dtype, device, stream); \ + } + +#define OP_FUNC(name, type_name, type, op, ...) FUNC(name, __VA_ARGS__##type_name, __VA_ARGS__##type) + +/* reuse the macro above, no work is actually done so we don't care about the func */ +#define FUNC_FUNC(name, type_name, type, ...) FUNC(name, __VA_ARGS__##type_name, __VA_ARGS__##type) + +/* + * Since all the functions in this file are essentially identical, we + * use a macro to substitute in names and types. The core operation + * in all functions that use this macro is the same. 
+ *
+ * This macro is for minloc and maxloc
+ */
+#define LOC_FUNC(name, type_name, op) FUNC(name, type_name, ompi_op_predefined_##type_name##_t)
+
+/* Dispatch Fortran integer types to the same-sized C fixed-width type.
+ * C11 requires the two-argument form of _Static_assert (the one-argument
+ * form is C23/GNU only), hence the message strings below. */
+#define FORT_INT_FUNC(name, type_name, type)                                             \
+    static                                                                               \
+    void ompi_op_cuda_2buff_##name##_##type_name(const void *in, void *inout, int *count, \
+                                                 struct ompi_datatype_t **dtype,         \
+                                                 int device,                             \
+                                                 opal_accelerator_stream_t *stream,      \
+                                                 struct ompi_op_base_module_1_0_0_t *module) { \
+                                                                                         \
+        _Static_assert(sizeof(type) >= sizeof(int8_t) && sizeof(type) <= sizeof(int64_t), \
+                       "Fortran integer type must map onto a C fixed-width integer");    \
+        switch(sizeof(type)) {                                                           \
+            case sizeof(int8_t):                                                         \
+                ompi_op_cuda_2buff_##name##_int8_t(in, inout, count, dtype, device, stream, module); \
+                break;                                                                   \
+            case sizeof(int16_t):                                                        \
+                ompi_op_cuda_2buff_##name##_int16_t(in, inout, count, dtype, device, stream, module); \
+                break;                                                                   \
+            case sizeof(int32_t):                                                        \
+                ompi_op_cuda_2buff_##name##_int32_t(in, inout, count, dtype, device, stream, module); \
+                break;                                                                   \
+            case sizeof(int64_t):                                                        \
+                ompi_op_cuda_2buff_##name##_int64_t(in, inout, count, dtype, device, stream, module); \
+                break;                                                                   \
+        }                                                                                \
+    }
+
+/* Dispatch Fortran real types to the same-sized C floating-point type.
+ * An if/else chain is used instead of a switch because
+ * sizeof(double) == sizeof(long double) on several platforms (MSVC,
+ * AArch64 ABIs with 64-bit long double), which would make the
+ * corresponding case labels collide and fail to compile. */
+#define FORT_FLOAT_FUNC(name, type_name, type)                                           \
+    static                                                                               \
+    void ompi_op_cuda_2buff_##name##_##type_name(const void *in, void *inout, int *count, \
+                                                 struct ompi_datatype_t **dtype,         \
+                                                 int device,                             \
+                                                 opal_accelerator_stream_t *stream,      \
+                                                 struct ompi_op_base_module_1_0_0_t *module) { \
+        _Static_assert(sizeof(type) >= sizeof(float) && sizeof(type) <= sizeof(long double), \
+                       "Fortran real type must map onto float, double, or long double"); \
+        if (sizeof(type) == sizeof(float)) {                                             \
+            ompi_op_cuda_2buff_##name##_float(in, inout, count, dtype, device, stream, module); \
+        } else if (sizeof(type) == sizeof(double)) {                                     \
+            ompi_op_cuda_2buff_##name##_double(in, inout, count, dtype, device, stream, module); \
+        } else {                                                                         \
+            ompi_op_cuda_2buff_##name##_long_double(in, inout, count, dtype, device, stream, module); \
+        }                                                                                \
+    }
+
+/************************************************************************* + * Max + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) > (b) ? (a) : (b)) +/* C integer */ +FUNC_FUNC(max, int8_t, int8_t) +FUNC_FUNC(max, uint8_t, uint8_t) +FUNC_FUNC(max, int16_t, int16_t) +FUNC_FUNC(max, uint16_t, uint16_t) +FUNC_FUNC(max, int32_t, int32_t) +FUNC_FUNC(max, uint32_t, uint32_t) +FUNC_FUNC(max, int64_t, int64_t) +FUNC_FUNC(max, uint64_t, uint64_t) +FUNC_FUNC(max, long, long) +FUNC_FUNC(max, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(max, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(max, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(max, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(max, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(max, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(max, fortran_integer16, ompi_fortran_integer16_t) +#endif + +#if 0 +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC(max, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC(max, short_float, opal_short_float_t) +#endif +#endif // 0 +FUNC_FUNC(max, float, float) +FUNC_FUNC(max, double, double) +FUNC_FUNC(max, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC(max, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC(max, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC(max, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC(max, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 
+FORT_FLOAT_FUNC(max, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC(max, fortran_real16, ompi_fortran_real16_t) +#endif + + +/************************************************************************* + * Min + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) < (b) ? (a) : (b)) +/* C integer */ +FUNC_FUNC(min, int8_t, int8_t) +FUNC_FUNC(min, uint8_t, uint8_t) +FUNC_FUNC(min, int16_t, int16_t) +FUNC_FUNC(min, uint16_t, uint16_t) +FUNC_FUNC(min, int32_t, int32_t) +FUNC_FUNC(min, uint32_t, uint32_t) +FUNC_FUNC(min, int64_t, int64_t) +FUNC_FUNC(min, uint64_t, uint64_t) +FUNC_FUNC(min, long, long) +FUNC_FUNC(min, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(min, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(min, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(min, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(min, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(min, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(min, fortran_integer16, ompi_fortran_integer16_t) +#endif + +#if 0 +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC(min, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC(min, short_float, opal_short_float_t) +#endif +#endif // 0 + +FUNC_FUNC(min, float, float) +FUNC_FUNC(min, double, double) +FUNC_FUNC(min, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC(min, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC(min, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if 
OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC(min, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC(min, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC(min, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC(min, fortran_real16, ompi_fortran_real16_t) +#endif + +/************************************************************************* + * Sum + *************************************************************************/ + +/* C integer */ +OP_FUNC(sum, int8_t, int8_t, +=) +OP_FUNC(sum, uint8_t, uint8_t, +=) +OP_FUNC(sum, int16_t, int16_t, +=) +OP_FUNC(sum, uint16_t, uint16_t, +=) +OP_FUNC(sum, int32_t, int32_t, +=) +OP_FUNC(sum, uint32_t, uint32_t, +=) +OP_FUNC(sum, int64_t, int64_t, +=) +OP_FUNC(sum, uint64_t, uint64_t, +=) +OP_FUNC(sum, long, long, +=) +OP_FUNC(sum, ulong, unsigned long, +=) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(sum, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(sum, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(sum, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(sum, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(sum, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(sum, fortran_integer16, ompi_fortran_integer16_t) +#endif + +#if 0 +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC(sum, short_float, short float, +=) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC(sum, short_float, opal_short_float_t, +=) +#endif +#endif // 0 + +OP_FUNC(sum, float, float, +=) +OP_FUNC(sum, double, double, +=) +OP_FUNC(sum, long_double, long double, +=) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC(sum, fortran_real, 
ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC(sum, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC(sum, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC(sum, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC(sum, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC(sum, fortran_real16, ompi_fortran_real16_t) +#endif +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC(sum, c_short_float_complex, short float _Complex, +=) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC(sum, c_long_double_complex, long double _Complex, +=) +#endif // 0 + +FUNC_FUNC(sum, c_float_complex, cuFloatComplex) +FUNC_FUNC(sum, c_double_complex, cuDoubleComplex) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +OP_FUNC(prod, int8_t, int8_t, *=) +OP_FUNC(prod, uint8_t, uint8_t, *=) +OP_FUNC(prod, int16_t, int16_t, *=) +OP_FUNC(prod, uint16_t, uint16_t, *=) +OP_FUNC(prod, int32_t, int32_t, *=) +OP_FUNC(prod, uint32_t, uint32_t, *=) +OP_FUNC(prod, int64_t, int64_t, *=) +OP_FUNC(prod, uint64_t, uint64_t, *=) +OP_FUNC(prod, long, long, *=) +OP_FUNC(prod, ulong, unsigned long, *=) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(prod, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(prod, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(prod, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(prod, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if 
OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(prod, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(prod, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Floating point */ + +#if 0 +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC(prod, short_float, short float, *=) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC(prod, short_float, opal_short_float_t, *=) +#endif +#endif // 0 + +OP_FUNC(prod, float, float, *=) +OP_FUNC(prod, double, double, *=) +OP_FUNC(prod, long_double, long double, *=) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC(prod, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC(prod, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC(prod, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC(prod, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC(prod, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC(prod, fortran_real16, ompi_fortran_real16_t) +#endif +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC(prod, c_short_float_complex, short float _Complex, *=) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_PROD_FUNC(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC(prod, c_long_double_complex, long double _Complex, *=) +#endif // 0 + +FUNC_FUNC(prod, c_float_complex, cuFloatComplex) +FUNC_FUNC(prod, c_double_complex, cuDoubleComplex) + +/************************************************************************* + * Logical AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) && (b)) +/* C integer */ +FUNC_FUNC(land, int8_t, int8_t) +FUNC_FUNC(land, uint8_t, uint8_t) +FUNC_FUNC(land, int16_t, int16_t) +FUNC_FUNC(land, uint16_t, uint16_t) 
+FUNC_FUNC(land, int32_t, int32_t) +FUNC_FUNC(land, uint32_t, uint32_t) +FUNC_FUNC(land, int64_t, int64_t) +FUNC_FUNC(land, uint64_t, uint64_t) +FUNC_FUNC(land, long, long) +FUNC_FUNC(land, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC(land, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC(land, bool, bool) + +/************************************************************************* + * Logical OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) || (b)) +/* C integer */ +FUNC_FUNC(lor, int8_t, int8_t) +FUNC_FUNC(lor, uint8_t, uint8_t) +FUNC_FUNC(lor, int16_t, int16_t) +FUNC_FUNC(lor, uint16_t, uint16_t) +FUNC_FUNC(lor, int32_t, int32_t) +FUNC_FUNC(lor, uint32_t, uint32_t) +FUNC_FUNC(lor, int64_t, int64_t) +FUNC_FUNC(lor, uint64_t, uint64_t) +FUNC_FUNC(lor, long, long) +FUNC_FUNC(lor, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC(lor, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC(lor, bool, bool) + +/************************************************************************* + * Logical XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a ? 1 : 0) ^ (b ? 
1: 0)) +/* C integer */ +FUNC_FUNC(lxor, int8_t, int8_t) +FUNC_FUNC(lxor, uint8_t, uint8_t) +FUNC_FUNC(lxor, int16_t, int16_t) +FUNC_FUNC(lxor, uint16_t, uint16_t) +FUNC_FUNC(lxor, int32_t, int32_t) +FUNC_FUNC(lxor, uint32_t, uint32_t) +FUNC_FUNC(lxor, int64_t, int64_t) +FUNC_FUNC(lxor, uint64_t, uint64_t) +FUNC_FUNC(lxor, long, long) +FUNC_FUNC(lxor, ulong, unsigned long) + + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC(lxor, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) & (b)) +/* C integer */ +FUNC_FUNC(band, int8_t, int8_t) +FUNC_FUNC(band, uint8_t, uint8_t) +FUNC_FUNC(band, int16_t, int16_t) +FUNC_FUNC(band, uint16_t, uint16_t) +FUNC_FUNC(band, int32_t, int32_t) +FUNC_FUNC(band, uint32_t, uint32_t) +FUNC_FUNC(band, int64_t, int64_t) +FUNC_FUNC(band, uint64_t, uint64_t) +FUNC_FUNC(band, long, long) +FUNC_FUNC(band, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(band, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(band, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(band, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(band, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(band, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(band, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FUNC_FUNC(band, byte, char) + +/************************************************************************* + * Bitwise OR + 
*************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) | (b)) +/* C integer */ +FUNC_FUNC(bor, int8_t, int8_t) +FUNC_FUNC(bor, uint8_t, uint8_t) +FUNC_FUNC(bor, int16_t, int16_t) +FUNC_FUNC(bor, uint16_t, uint16_t) +FUNC_FUNC(bor, int32_t, int32_t) +FUNC_FUNC(bor, uint32_t, uint32_t) +FUNC_FUNC(bor, int64_t, int64_t) +FUNC_FUNC(bor, uint64_t, uint64_t) +FUNC_FUNC(bor, long, long) +FUNC_FUNC(bor, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(bor, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(bor, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(bor, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(bor, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(bor, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(bor, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FUNC_FUNC(bor, byte, char) + +/************************************************************************* + * Bitwise XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) ^ (b)) +/* C integer */ +FUNC_FUNC(bxor, int8_t, int8_t) +FUNC_FUNC(bxor, uint8_t, uint8_t) +FUNC_FUNC(bxor, int16_t, int16_t) +FUNC_FUNC(bxor, uint16_t, uint16_t) +FUNC_FUNC(bxor, int32_t, int32_t) +FUNC_FUNC(bxor, uint32_t, uint32_t) +FUNC_FUNC(bxor, int64_t, int64_t) +FUNC_FUNC(bxor, uint64_t, uint64_t) +FUNC_FUNC(bxor, long, long) +FUNC_FUNC(bxor, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(bxor, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(bxor, fortran_integer1, 
ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(bxor, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(bxor, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(bxor, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(bxor, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FUNC_FUNC(bxor, byte, char) + +/************************************************************************* + * Max location + *************************************************************************/ + +LOC_FUNC(maxloc, float_int, >) +LOC_FUNC(maxloc, double_int, >) +LOC_FUNC(maxloc, long_int, >) +LOC_FUNC(maxloc, 2int, >) +LOC_FUNC(maxloc, short_int, >) +LOC_FUNC(maxloc, long_double_int, >) + +/************************************************************************* + * Min location + *************************************************************************/ + +LOC_FUNC(minloc, float_int, <) +LOC_FUNC(minloc, double_int, <) +LOC_FUNC(minloc, long_int, <) +LOC_FUNC(minloc, 2int, <) +LOC_FUNC(minloc, short_int, <) +LOC_FUNC(minloc, long_double_int, <) + + + +/* + * This is a three buffer (2 input and 1 output) version of the reduction + * routines, needed for some optimizations. 
+ */ +#define FUNC_3BUF(name, type_name, type) \ + static \ + void ompi_op_cuda_3buff_##name##_##type_name(const void *in1, const void *in2, void *out, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + int threads_per_block, max_blocks; \ + int source1_device, source2_device, target_device; \ + type *source1, *source2, *target; \ + int n = *count; \ + device_op_pre(in1, (void**)&source1, &source1_device, \ + in2, (void**)&source2, &source2_device, \ + out, (void**)&target, &target_device, \ + n, *dtype, \ + &threads_per_block, &max_blocks, &device, stream); \ + CUstream *custream = (CUstream*)stream->stream; \ + ompi_op_cuda_3buff_##name##_##type_name##_submit(source1, source2, target, n, threads_per_block, max_blocks, *custream);\ + device_op_post(source1, source1_device, source2, source2_device, out, target, target_device, n, *dtype, device, stream);\ + } + + +#define OP_FUNC_3BUF(name, type_name, type, op, ...) FUNC_3BUF(name, __VA_ARGS__##type_name, __VA_ARGS__##type) + +/* reuse the macro above, no work is actually done so we don't care about the func */ +#define FUNC_FUNC_3BUF(name, type_name, type, ...) FUNC_3BUF(name, __VA_ARGS__##type_name, __VA_ARGS__##type) + +/* + * Since all the functions in this file are essentially identical, we + * use a macro to substitute in names and types. The core operation + * in all functions that use this macro is the same. 
+ * + * This macro is for minloc and maxloc + */ +#define LOC_FUNC_3BUF(name, type_name, op) FUNC_3BUF(name, type_name, ompi_op_predefined_##type_name##_t) + + +/* Dispatch Fortran types to C types */ +#define FORT_INT_FUNC_3BUF(name, type_name, type) \ + static \ + void ompi_op_cuda_3buff_##name##_##type_name(const void *in1, const void *in2, void *out, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + \ + _Static_assert(sizeof(type) >= sizeof(int8_t) && sizeof(type) <= sizeof(int64_t)); \ + switch(sizeof(type)) { \ + case sizeof(int8_t): \ + ompi_op_cuda_3buff_##name##_int8_t(in1, in2, out, count, dtype, device, stream, module); \ + break; \ + case sizeof(int16_t): \ + ompi_op_cuda_3buff_##name##_int16_t(in1, in2, out, count, dtype, device, stream, module); \ + break; \ + case sizeof(int32_t): \ + ompi_op_cuda_3buff_##name##_int32_t(in1, in2, out, count, dtype, device, stream, module); \ + break; \ + case sizeof(int64_t): \ + ompi_op_cuda_3buff_##name##_int64_t(in1, in2, out, count, dtype, device, stream, module); \ + break; \ + } \ + } + +/* Dispatch Fortran types to C types */ +#define FORT_FLOAT_FUNC_3BUF(name, type_name, type) \ + static \ + void ompi_op_cuda_3buff_##name##_##type_name(const void *in1, const void *in2, void *out, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + _Static_assert(sizeof(type) >= sizeof(float) && sizeof(type) <= sizeof(long double)); \ + switch(sizeof(type)) { \ + case sizeof(float): \ + ompi_op_cuda_3buff_##name##_float(in1, in2, out, count, dtype, device, stream, module); \ + break; \ + case sizeof(double): \ + ompi_op_cuda_3buff_##name##_double(in1, in2, out, count, dtype, device, stream, module); \ + break; \ + case sizeof(long double): \ + ompi_op_cuda_3buff_##name##_long_double(in1, in2, out, count, dtype, device, 
stream, module); \ + break; \ + } \ + } + + +/************************************************************************* + * Max + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) > (b) ? (a) : (b)) +/* C integer */ +FUNC_FUNC_3BUF(max, int8_t, int8_t) +FUNC_FUNC_3BUF(max, uint8_t, uint8_t) +FUNC_FUNC_3BUF(max, int16_t, int16_t) +FUNC_FUNC_3BUF(max, uint16_t, uint16_t) +FUNC_FUNC_3BUF(max, int32_t, int32_t) +FUNC_FUNC_3BUF(max, uint32_t, uint32_t) +FUNC_FUNC_3BUF(max, int64_t, int64_t) +FUNC_FUNC_3BUF(max, uint64_t, uint64_t) +FUNC_FUNC_3BUF(max, long, long) +FUNC_FUNC_3BUF(max, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(max, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(max, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(max, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(max, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(max, fortran_integer8, ompi_fortran_integer8_t) +#endif +/* Floating point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF(max, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF(max, short_float, opal_short_float_t) +#endif +#endif // 0 +FUNC_FUNC_3BUF(max, float, float) +FUNC_FUNC_3BUF(max, double, double) +FUNC_FUNC_3BUF(max, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC_3BUF(max, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC_3BUF(max, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC_3BUF(max, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC_3BUF(max, fortran_real4, 
ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC_3BUF(max, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC_3BUF(max, fortran_real16, ompi_fortran_real16_t) +#endif + + +/************************************************************************* + * Min + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) < (b) ? (a) : (b)) +/* C integer */ +FUNC_FUNC_3BUF(min, int8_t, int8_t) +FUNC_FUNC_3BUF(min, uint8_t, uint8_t) +FUNC_FUNC_3BUF(min, int16_t, int16_t) +FUNC_FUNC_3BUF(min, uint16_t, uint16_t) +FUNC_FUNC_3BUF(min, int32_t, int32_t) +FUNC_FUNC_3BUF(min, uint32_t, uint32_t) +FUNC_FUNC_3BUF(min, int64_t, int64_t) +FUNC_FUNC_3BUF(min, uint64_t, uint64_t) +FUNC_FUNC_3BUF(min, long, long) +FUNC_FUNC_3BUF(min, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(min, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(min, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(min, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(min, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(min, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(min, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Floating point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF(min, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF(min, short_float, opal_short_float_t) +#endif +#endif // 0 +FUNC_FUNC_3BUF(min, float, float) +FUNC_FUNC_3BUF(min, double, double) +FUNC_FUNC_3BUF(min, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC_3BUF(min, fortran_real, 
ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC_3BUF(min, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC_3BUF(min, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC_3BUF(min, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC_3BUF(min, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC_3BUF(min, fortran_real16, ompi_fortran_real16_t) +#endif + +/************************************************************************* + * Sum + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF(sum, int8_t, int8_t, +) +OP_FUNC_3BUF(sum, uint8_t, uint8_t, +) +OP_FUNC_3BUF(sum, int16_t, int16_t, +) +OP_FUNC_3BUF(sum, uint16_t, uint16_t, +) +OP_FUNC_3BUF(sum, int32_t, int32_t, +) +OP_FUNC_3BUF(sum, uint32_t, uint32_t, +) +OP_FUNC_3BUF(sum, int64_t, int64_t, +) +OP_FUNC_3BUF(sum, uint64_t, uint64_t, +) +OP_FUNC_3BUF(sum, long, long, +) +OP_FUNC_3BUF(sum, ulong, unsigned long, +) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(sum, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(sum, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(sum, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(sum, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(sum, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(sum, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Floating point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_3BUF(sum, short_float, short float, +) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) 
+OP_FUNC_3BUF(sum, short_float, opal_short_float_t, +) +#endif +#endif // 0 +OP_FUNC_3BUF(sum, float, float, +) +OP_FUNC_3BUF(sum, double, double, +) +OP_FUNC_3BUF(sum, long_double, long double, +) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC_3BUF(sum, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC_3BUF(sum, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC_3BUF(sum, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC_3BUF(sum, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC_3BUF(sum, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC_3BUF(sum, fortran_real16, ompi_fortran_real16_t) +#endif +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF(sum, c_short_float_complex, short float _Complex, +) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_3BUF(sum, c_float_complex, float _Complex, +) +OP_FUNC_3BUF(sum, c_double_complex, double _Complex, +) +OP_FUNC_3BUF(sum, c_long_double_complex, long double _Complex, +) +#endif // 0 + +FUNC_FUNC_3BUF(sum, c_float_complex, cuFloatComplex) +FUNC_FUNC_3BUF(sum, c_double_complex, cuDoubleComplex) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF(prod, int8_t, int8_t, *) +OP_FUNC_3BUF(prod, uint8_t, uint8_t, *) +OP_FUNC_3BUF(prod, int16_t, int16_t, *) +OP_FUNC_3BUF(prod, uint16_t, uint16_t, *) +OP_FUNC_3BUF(prod, int32_t, int32_t, *) +OP_FUNC_3BUF(prod, uint32_t, uint32_t, *) +OP_FUNC_3BUF(prod, int64_t, int64_t, *) +OP_FUNC_3BUF(prod, uint64_t, uint64_t, *) +OP_FUNC_3BUF(prod, long, long, *) +OP_FUNC_3BUF(prod, ulong, 
unsigned long, *) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(prod, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(prod, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(prod, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(prod, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(prod, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(prod, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Floating point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +FORT_FLOAT_FUNC_3BUF(prod, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FORT_FLOAT_FUNC_3BUF(prod, short_float, opal_short_float_t) +#endif +#endif // 0 +OP_FUNC_3BUF(prod, float, float, *) +OP_FUNC_3BUF(prod, double, double, *) +OP_FUNC_3BUF(prod, long_double, long double, *) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC_3BUF(prod, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC_3BUF(prod, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC_3BUF(prod, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC_3BUF(prod, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC_3BUF(prod, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC_3BUF(prod, fortran_real16, ompi_fortran_real16_t) +#endif +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF(prod, c_short_float_complex, short float _Complex, *) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_PROD_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif 
+OP_FUNC_3BUF(prod, c_long_double_complex, long double _Complex, *) +#endif // 0 + +FUNC_FUNC_3BUF(prod, c_float_complex, cuFloatComplex) +FUNC_FUNC_3BUF(prod, c_double_complex, cuDoubleComplex) + +/************************************************************************* + * Logical AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) && (b)) +/* C integer */ +FUNC_FUNC_3BUF(land, int8_t, int8_t) +FUNC_FUNC_3BUF(land, uint8_t, uint8_t) +FUNC_FUNC_3BUF(land, int16_t, int16_t) +FUNC_FUNC_3BUF(land, uint16_t, uint16_t) +FUNC_FUNC_3BUF(land, int32_t, int32_t) +FUNC_FUNC_3BUF(land, uint32_t, uint32_t) +FUNC_FUNC_3BUF(land, int64_t, int64_t) +FUNC_FUNC_3BUF(land, uint64_t, uint64_t) +FUNC_FUNC_3BUF(land, long, long) +FUNC_FUNC_3BUF(land, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC_3BUF(land, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC_3BUF(land, bool, bool) + +/************************************************************************* + * Logical OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) || (b)) +/* C integer */ +FUNC_FUNC_3BUF(lor, int8_t, int8_t) +FUNC_FUNC_3BUF(lor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(lor, int16_t, int16_t) +FUNC_FUNC_3BUF(lor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(lor, int32_t, int32_t) +FUNC_FUNC_3BUF(lor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(lor, int64_t, int64_t) +FUNC_FUNC_3BUF(lor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(lor, long, long) +FUNC_FUNC_3BUF(lor, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC_3BUF(lor, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC_3BUF(lor, bool, bool) + +/************************************************************************* + * Logical XOR + 
*************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a ? 1 : 0) ^ (b ? 1: 0)) +/* C integer */ +FUNC_FUNC_3BUF(lxor, int8_t, int8_t) +FUNC_FUNC_3BUF(lxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(lxor, int16_t, int16_t) +FUNC_FUNC_3BUF(lxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(lxor, int32_t, int32_t) +FUNC_FUNC_3BUF(lxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(lxor, int64_t, int64_t) +FUNC_FUNC_3BUF(lxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(lxor, long, long) +FUNC_FUNC_3BUF(lxor, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC_3BUF(lxor, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC_3BUF(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) & (b)) +/* C integer */ +FUNC_FUNC_3BUF(band, int8_t, int8_t) +FUNC_FUNC_3BUF(band, uint8_t, uint8_t) +FUNC_FUNC_3BUF(band, int16_t, int16_t) +FUNC_FUNC_3BUF(band, uint16_t, uint16_t) +FUNC_FUNC_3BUF(band, int32_t, int32_t) +FUNC_FUNC_3BUF(band, uint32_t, uint32_t) +FUNC_FUNC_3BUF(band, int64_t, int64_t) +FUNC_FUNC_3BUF(band, uint64_t, uint64_t) +FUNC_FUNC_3BUF(band, long, long) +FUNC_FUNC_3BUF(band, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(band, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(band, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(band, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(band, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(band, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 
+FORT_INT_FUNC_3BUF(band, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FORT_INT_FUNC_3BUF(band, byte, char) + +/************************************************************************* + * Bitwise OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) | (b)) +/* C integer */ +FUNC_FUNC_3BUF(bor, int8_t, int8_t) +FUNC_FUNC_3BUF(bor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(bor, int16_t, int16_t) +FUNC_FUNC_3BUF(bor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(bor, int32_t, int32_t) +FUNC_FUNC_3BUF(bor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(bor, int64_t, int64_t) +FUNC_FUNC_3BUF(bor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(bor, long, long) +FUNC_FUNC_3BUF(bor, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(bor, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(bor, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(bor, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(bor, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(bor, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(bor, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FORT_INT_FUNC_3BUF(bor, byte, char) + +/************************************************************************* + * Bitwise XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) ^ (b)) +/* C integer */ +FUNC_FUNC_3BUF(bxor, int8_t, int8_t) +FUNC_FUNC_3BUF(bxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(bxor, int16_t, int16_t) +FUNC_FUNC_3BUF(bxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(bxor, int32_t, int32_t) +FUNC_FUNC_3BUF(bxor, uint32_t, uint32_t) 
+FUNC_FUNC_3BUF(bxor, int64_t, int64_t) +FUNC_FUNC_3BUF(bxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(bxor, long, long) +FUNC_FUNC_3BUF(bxor, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(bxor, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(bxor, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(bxor, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(bxor, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(bxor, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(bxor, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FORT_INT_FUNC_3BUF(bxor, byte, char) + +/************************************************************************* + * Max location + *************************************************************************/ + +LOC_FUNC_3BUF(maxloc, float_int, >) +LOC_FUNC_3BUF(maxloc, double_int, >) +LOC_FUNC_3BUF(maxloc, long_int, >) +LOC_FUNC_3BUF(maxloc, 2int, >) +LOC_FUNC_3BUF(maxloc, short_int, >) +LOC_FUNC_3BUF(maxloc, long_double_int, >) + +/************************************************************************* + * Min location + *************************************************************************/ + +LOC_FUNC_3BUF(minloc, float_int, <) +LOC_FUNC_3BUF(minloc, double_int, <) +LOC_FUNC_3BUF(minloc, long_int, <) +LOC_FUNC_3BUF(minloc, 2int, <) +LOC_FUNC_3BUF(minloc, short_int, <) +LOC_FUNC_3BUF(minloc, long_double_int, <) + +/* + * Helpful defines, because there's soooo many names! + * + * **NOTE** These #define's used to be strictly ordered but the use of + * designated initializers removed this restrictions. When adding new + * operators ALWAYS use a designated initializer! 
+ */ + +/** C integer ***********************************************************/ +#define C_INTEGER(name, ftype) \ + [OMPI_OP_BASE_TYPE_INT8_T] = ompi_op_cuda_##ftype##_##name##_int8_t, \ + [OMPI_OP_BASE_TYPE_UINT8_T] = ompi_op_cuda_##ftype##_##name##_uint8_t, \ + [OMPI_OP_BASE_TYPE_INT16_T] = ompi_op_cuda_##ftype##_##name##_int16_t, \ + [OMPI_OP_BASE_TYPE_UINT16_T] = ompi_op_cuda_##ftype##_##name##_uint16_t, \ + [OMPI_OP_BASE_TYPE_INT32_T] = ompi_op_cuda_##ftype##_##name##_int32_t, \ + [OMPI_OP_BASE_TYPE_UINT32_T] = ompi_op_cuda_##ftype##_##name##_uint32_t, \ + [OMPI_OP_BASE_TYPE_INT64_T] = ompi_op_cuda_##ftype##_##name##_int64_t, \ + [OMPI_OP_BASE_TYPE_LONG] = ompi_op_cuda_##ftype##_##name##_long, \ + [OMPI_OP_BASE_TYPE_UNSIGNED_LONG] = ompi_op_cuda_##ftype##_##name##_ulong, \ + [OMPI_OP_BASE_TYPE_UINT64_T] = ompi_op_cuda_##ftype##_##name##_uint64_t + +/** All the Fortran integers ********************************************/ + +#if OMPI_HAVE_FORTRAN_INTEGER +#define FORTRAN_INTEGER_PLAIN(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_integer +#else +#define FORTRAN_INTEGER_PLAIN(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +#define FORTRAN_INTEGER1(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_integer1 +#else +#define FORTRAN_INTEGER1(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +#define FORTRAN_INTEGER2(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_integer2 +#else +#define FORTRAN_INTEGER2(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +#define FORTRAN_INTEGER4(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_integer4 +#else +#define FORTRAN_INTEGER4(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +#define FORTRAN_INTEGER8(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_integer8 +#else +#define FORTRAN_INTEGER8(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +#define FORTRAN_INTEGER16(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_integer16 +#else +#define 
FORTRAN_INTEGER16(name, ftype) NULL +#endif + +#define FORTRAN_INTEGER(name, ftype) \ + [OMPI_OP_BASE_TYPE_INTEGER] = FORTRAN_INTEGER_PLAIN(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER1] = FORTRAN_INTEGER1(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER2] = FORTRAN_INTEGER2(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER4] = FORTRAN_INTEGER4(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER8] = FORTRAN_INTEGER8(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER16] = FORTRAN_INTEGER16(name, ftype) + +/** All the Fortran reals ***********************************************/ + +#if OMPI_HAVE_FORTRAN_REAL +#define FLOATING_POINT_FORTRAN_REAL_PLAIN(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_real +#else +#define FLOATING_POINT_FORTRAN_REAL_PLAIN(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +#define FLOATING_POINT_FORTRAN_REAL2(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_real2 +#else +#define FLOATING_POINT_FORTRAN_REAL2(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +#define FLOATING_POINT_FORTRAN_REAL4(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_real4 +#else +#define FLOATING_POINT_FORTRAN_REAL4(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +#define FLOATING_POINT_FORTRAN_REAL8(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_real8 +#else +#define FLOATING_POINT_FORTRAN_REAL8(name, ftype) NULL +#endif +/* If: + - we have fortran REAL*16, *and* + - fortran REAL*16 matches the bit representation of the + corresponding C type + Only then do we put in function pointers for REAL*16 reductions. + Otherwise, just put in NULL. 
*/ +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +#define FLOATING_POINT_FORTRAN_REAL16(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_real16 +#else +#define FLOATING_POINT_FORTRAN_REAL16(name, ftype) NULL +#endif + +#define FLOATING_POINT_FORTRAN_REAL(name, ftype) \ + [OMPI_OP_BASE_TYPE_REAL] = FLOATING_POINT_FORTRAN_REAL_PLAIN(name, ftype), \ + [OMPI_OP_BASE_TYPE_REAL2] = FLOATING_POINT_FORTRAN_REAL2(name, ftype), \ + [OMPI_OP_BASE_TYPE_REAL4] = FLOATING_POINT_FORTRAN_REAL4(name, ftype), \ + [OMPI_OP_BASE_TYPE_REAL8] = FLOATING_POINT_FORTRAN_REAL8(name, ftype), \ + [OMPI_OP_BASE_TYPE_REAL16] = FLOATING_POINT_FORTRAN_REAL16(name, ftype) + +/** Fortran double precision ********************************************/ + +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +#define FLOATING_POINT_FORTRAN_DOUBLE_PRECISION(name, ftype) \ + ompi_op_cuda_##ftype##_##name##_fortran_double_precision +#else +#define FLOATING_POINT_FORTRAN_DOUBLE_PRECISION(name, ftype) NULL +#endif + +/** Floating point, including all the Fortran reals *********************/ + +//#if defined(HAVE_SHORT_FLOAT) || defined(HAVE_OPAL_SHORT_FLOAT_T) +//#define SHORT_FLOAT(name, ftype) ompi_op_cuda_##ftype##_##name##_short_float +//#else +#define SHORT_FLOAT(name, ftype) NULL +//#endif +#define FLOAT(name, ftype) ompi_op_cuda_##ftype##_##name##_float +#define DOUBLE(name, ftype) ompi_op_cuda_##ftype##_##name##_double +#define LONG_DOUBLE(name, ftype) ompi_op_cuda_##ftype##_##name##_long_double + +#define FLOATING_POINT(name, ftype) \ + [OMPI_OP_BASE_TYPE_SHORT_FLOAT] = SHORT_FLOAT(name, ftype), \ + [OMPI_OP_BASE_TYPE_FLOAT] = FLOAT(name, ftype), \ + [OMPI_OP_BASE_TYPE_DOUBLE] = DOUBLE(name, ftype), \ + FLOATING_POINT_FORTRAN_REAL(name, ftype), \ + [OMPI_OP_BASE_TYPE_DOUBLE_PRECISION] = FLOATING_POINT_FORTRAN_DOUBLE_PRECISION(name, ftype), \ + [OMPI_OP_BASE_TYPE_LONG_DOUBLE] = LONG_DOUBLE(name, ftype) + +/** Fortran logical *****************************************************/ + +#if 
OMPI_HAVE_FORTRAN_LOGICAL +#define FORTRAN_LOGICAL(name, ftype) \ + ompi_op_cuda_##ftype##_##name##_fortran_logical /* OMPI_OP_CUDA_TYPE_LOGICAL */ +#else +#define FORTRAN_LOGICAL(name, ftype) NULL +#endif + +#define LOGICAL(name, ftype) \ + [OMPI_OP_BASE_TYPE_LOGICAL] = FORTRAN_LOGICAL(name, ftype), \ + [OMPI_OP_BASE_TYPE_BOOL] = ompi_op_cuda_##ftype##_##name##_bool + +/** Complex *****************************************************/ +#if 0 + +#if defined(HAVE_SHORT_FLOAT__COMPLEX) || defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +#define SHORT_FLOAT_COMPLEX(name, ftype) ompi_op_cuda_##ftype##_##name##_c_short_float_complex +#else +#define SHORT_FLOAT_COMPLEX(name, ftype) NULL +#endif +#define LONG_DOUBLE_COMPLEX(name, ftype) ompi_op_cuda_##ftype##_##name##_c_long_double_complex +#else +#define SHORT_FLOAT_COMPLEX(name, ftype) NULL +#define LONG_DOUBLE_COMPLEX(name, ftype) NULL +#endif // 0 +#define FLOAT_COMPLEX(name, ftype) ompi_op_cuda_##ftype##_##name##_c_float_complex +#define DOUBLE_COMPLEX(name, ftype) ompi_op_cuda_##ftype##_##name##_c_double_complex + +#define COMPLEX(name, ftype) \ + [OMPI_OP_BASE_TYPE_C_SHORT_FLOAT_COMPLEX] = SHORT_FLOAT_COMPLEX(name, ftype), \ + [OMPI_OP_BASE_TYPE_C_FLOAT_COMPLEX] = FLOAT_COMPLEX(name, ftype), \ + [OMPI_OP_BASE_TYPE_C_DOUBLE_COMPLEX] = DOUBLE_COMPLEX(name, ftype), \ + [OMPI_OP_BASE_TYPE_C_LONG_DOUBLE_COMPLEX] = LONG_DOUBLE_COMPLEX(name, ftype) + +/** Byte ****************************************************************/ + +#define BYTE(name, ftype) \ + [OMPI_OP_BASE_TYPE_BYTE] = ompi_op_cuda_##ftype##_##name##_byte + +/** Fortran complex *****************************************************/ +/** Fortran "2" types ***************************************************/ + +#if OMPI_HAVE_FORTRAN_REAL && OMPI_SIZEOF_FLOAT == OMPI_SIZEOF_FORTRAN_REAL +#define TWOLOC_FORTRAN_2REAL(name, ftype) ompi_op_cuda_##ftype##_##name##_2double_precision +#else +#define TWOLOC_FORTRAN_2REAL(name, ftype) NULL +#endif + +#if 
OMPI_HAVE_FORTRAN_DOUBLE_PRECISION && OMPI_SIZEOF_DOUBLE == OMPI_SIZEOF_FORTRAN_DOUBLE_PRECISION +#define TWOLOC_FORTRAN_2DOUBLE_PRECISION(name, ftype) ompi_op_cuda_##ftype##_##name##_2double_precision +#else +#define TWOLOC_FORTRAN_2DOUBLE_PRECISION(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER && OMPI_SIZEOF_INT == OMPI_SIZEOF_FORTRAN_INTEGER +#define TWOLOC_FORTRAN_2INTEGER(name, ftype) ompi_op_cuda_##ftype##_##name##_2int +#else +#define TWOLOC_FORTRAN_2INTEGER(name, ftype) NULL +#endif + +/** All "2" types *******************************************************/ + +#define TWOLOC(name, ftype) \ + [OMPI_OP_BASE_TYPE_2REAL] = TWOLOC_FORTRAN_2REAL(name, ftype), \ + [OMPI_OP_BASE_TYPE_2DOUBLE_PRECISION] = TWOLOC_FORTRAN_2DOUBLE_PRECISION(name, ftype), \ + [OMPI_OP_BASE_TYPE_2INTEGER] = TWOLOC_FORTRAN_2INTEGER(name, ftype), \ + [OMPI_OP_BASE_TYPE_FLOAT_INT] = ompi_op_cuda_##ftype##_##name##_float_int, \ + [OMPI_OP_BASE_TYPE_DOUBLE_INT] = ompi_op_cuda_##ftype##_##name##_double_int, \ + [OMPI_OP_BASE_TYPE_LONG_INT] = ompi_op_cuda_##ftype##_##name##_long_int, \ + [OMPI_OP_BASE_TYPE_2INT] = ompi_op_cuda_##ftype##_##name##_2int, \ + [OMPI_OP_BASE_TYPE_SHORT_INT] = ompi_op_cuda_##ftype##_##name##_short_int, \ + [OMPI_OP_BASE_TYPE_LONG_DOUBLE_INT] = ompi_op_cuda_##ftype##_##name##_long_double_int + +/* + * MPI_OP_NULL + * All types + */ +#define FLAGS_NO_FLOAT \ + (OMPI_OP_FLAGS_INTRINSIC | OMPI_OP_FLAGS_ASSOC | OMPI_OP_FLAGS_COMMUTE) +#define FLAGS \ + (OMPI_OP_FLAGS_INTRINSIC | OMPI_OP_FLAGS_ASSOC | \ + OMPI_OP_FLAGS_FLOAT_ASSOC | OMPI_OP_FLAGS_COMMUTE) + +ompi_op_base_stream_handler_fn_t ompi_op_cuda_functions[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX] = + { + /* Corresponds to MPI_OP_NULL */ + [OMPI_OP_BASE_FORTRAN_NULL] = { + /* Leaving this empty puts in NULL for all entries */ + NULL, + }, + /* Corresponds to MPI_MAX */ + [OMPI_OP_BASE_FORTRAN_MAX] = { + C_INTEGER(max, 2buff), + FORTRAN_INTEGER(max, 2buff), + FLOATING_POINT(max, 2buff), + }, + 
/* Corresponds to MPI_MIN */ + [OMPI_OP_BASE_FORTRAN_MIN] = { + C_INTEGER(min, 2buff), + FORTRAN_INTEGER(min, 2buff), + FLOATING_POINT(min, 2buff), + }, + /* Corresponds to MPI_SUM */ + [OMPI_OP_BASE_FORTRAN_SUM] = { + C_INTEGER(sum, 2buff), + FORTRAN_INTEGER(sum, 2buff), + FLOATING_POINT(sum, 2buff), + COMPLEX(sum, 2buff), + }, + /* Corresponds to MPI_PROD */ + [OMPI_OP_BASE_FORTRAN_PROD] = { + C_INTEGER(prod, 2buff), + FORTRAN_INTEGER(prod, 2buff), + FLOATING_POINT(prod, 2buff), + COMPLEX(prod, 2buff), + }, + /* Corresponds to MPI_LAND */ + [OMPI_OP_BASE_FORTRAN_LAND] = { + C_INTEGER(land, 2buff), + LOGICAL(land, 2buff), + }, + /* Corresponds to MPI_BAND */ + [OMPI_OP_BASE_FORTRAN_BAND] = { + C_INTEGER(band, 2buff), + FORTRAN_INTEGER(band, 2buff), + BYTE(band, 2buff), + }, + /* Corresponds to MPI_LOR */ + [OMPI_OP_BASE_FORTRAN_LOR] = { + C_INTEGER(lor, 2buff), + LOGICAL(lor, 2buff), + }, + /* Corresponds to MPI_BOR */ + [OMPI_OP_BASE_FORTRAN_BOR] = { + C_INTEGER(bor, 2buff), + FORTRAN_INTEGER(bor, 2buff), + BYTE(bor, 2buff), + }, + /* Corresponds to MPI_LXOR */ + [OMPI_OP_BASE_FORTRAN_LXOR] = { + C_INTEGER(lxor, 2buff), + LOGICAL(lxor, 2buff), + }, + /* Corresponds to MPI_BXOR */ + [OMPI_OP_BASE_FORTRAN_BXOR] = { + C_INTEGER(bxor, 2buff), + FORTRAN_INTEGER(bxor, 2buff), + BYTE(bxor, 2buff), + }, + /* Corresponds to MPI_MAXLOC */ + [OMPI_OP_BASE_FORTRAN_MAXLOC] = { + TWOLOC(maxloc, 2buff), + }, + /* Corresponds to MPI_MINLOC */ + [OMPI_OP_BASE_FORTRAN_MINLOC] = { + TWOLOC(minloc, 2buff), + }, + /* Corresponds to MPI_REPLACE */ + [OMPI_OP_BASE_FORTRAN_REPLACE] = { + /* (MPI_ACCUMULATE is handled differently than the other + reductions, so just zero out its function + implementations here to ensure that users don't invoke + MPI_REPLACE with any reduction operations other than + ACCUMULATE) */ + NULL, + }, + + }; + +ompi_op_base_3buff_stream_handler_fn_t ompi_op_cuda_3buff_functions[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX] = + { + /* Corresponds to 
MPI_OP_NULL */ + [OMPI_OP_BASE_FORTRAN_NULL] = { + /* Leaving this empty puts in NULL for all entries */ + NULL, + }, + /* Corresponds to MPI_MAX */ + [OMPI_OP_BASE_FORTRAN_MAX] = { + C_INTEGER(max, 3buff), + FORTRAN_INTEGER(max, 3buff), + FLOATING_POINT(max, 3buff), + }, + /* Corresponds to MPI_MIN */ + [OMPI_OP_BASE_FORTRAN_MIN] = { + C_INTEGER(min, 3buff), + FORTRAN_INTEGER(min, 3buff), + FLOATING_POINT(min, 3buff), + }, + /* Corresponds to MPI_SUM */ + [OMPI_OP_BASE_FORTRAN_SUM] = { + C_INTEGER(sum, 3buff), + FORTRAN_INTEGER(sum, 3buff), + FLOATING_POINT(sum, 3buff), + COMPLEX(sum, 3buff), + }, + /* Corresponds to MPI_PROD */ + [OMPI_OP_BASE_FORTRAN_PROD] = { + C_INTEGER(prod, 3buff), + FORTRAN_INTEGER(prod, 3buff), + FLOATING_POINT(prod, 3buff), + COMPLEX(prod, 3buff), + }, + /* Corresponds to MPI_LAND */ + [OMPI_OP_BASE_FORTRAN_LAND] ={ + C_INTEGER(land, 3buff), + LOGICAL(land, 3buff), + }, + /* Corresponds to MPI_BAND */ + [OMPI_OP_BASE_FORTRAN_BAND] = { + C_INTEGER(band, 3buff), + FORTRAN_INTEGER(band, 3buff), + BYTE(band, 3buff), + }, + /* Corresponds to MPI_LOR */ + [OMPI_OP_BASE_FORTRAN_LOR] = { + C_INTEGER(lor, 3buff), + LOGICAL(lor, 3buff), + }, + /* Corresponds to MPI_BOR */ + [OMPI_OP_BASE_FORTRAN_BOR] = { + C_INTEGER(bor, 3buff), + FORTRAN_INTEGER(bor, 3buff), + BYTE(bor, 3buff), + }, + /* Corresponds to MPI_LXOR */ + [OMPI_OP_BASE_FORTRAN_LXOR] = { + C_INTEGER(lxor, 3buff), + LOGICAL(lxor, 3buff), + }, + /* Corresponds to MPI_BXOR */ + [OMPI_OP_BASE_FORTRAN_BXOR] = { + C_INTEGER(bxor, 3buff), + FORTRAN_INTEGER(bxor, 3buff), + BYTE(bxor, 3buff), + }, + /* Corresponds to MPI_MAXLOC */ + [OMPI_OP_BASE_FORTRAN_MAXLOC] = { + TWOLOC(maxloc, 3buff), + }, + /* Corresponds to MPI_MINLOC */ + [OMPI_OP_BASE_FORTRAN_MINLOC] = { + TWOLOC(minloc, 3buff), + }, + /* Corresponds to MPI_REPLACE */ + [OMPI_OP_BASE_FORTRAN_REPLACE] = { + /* MPI_ACCUMULATE is handled differently than the other + reductions, so just zero out its function + implementations here to ensure 
that users don't invoke + MPI_REPLACE with any reduction operations other than + ACCUMULATE */ + NULL, + }, + }; diff --git a/ompi/mca/op/cuda/op_cuda_impl.cu b/ompi/mca/op/cuda/op_cuda_impl.cu new file mode 100644 index 00000000000..79c82feaa19 --- /dev/null +++ b/ompi/mca/op/cuda/op_cuda_impl.cu @@ -0,0 +1,1049 @@ +/* + * Copyright (c) 2019-2023 The University of Tennessee and The University + * of Tennessee Research Foundation. All rights + * reserved. + * Copyright (c) 2020 Research Organization for Information Science + * and Technology (RIST). All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "op_cuda_impl.h" + +#include + +#include + +#define ISSIGNED(x) std::is_signed_v + +template +static inline __device__ constexpr T tmax(T a, T b) { + return (a > b) ? a : b; +} + +template +static inline __device__ constexpr T tmin(T a, T b) { + return (a < b) ? a : b; +} + +template +static inline __device__ constexpr T tsum(T a, T b) { + return a+b; +} + +template +static inline __device__ constexpr T tprod(T a, T b) { + return a*b; +} + +template +static inline __device__ T vmax(const T& a, const T& b) { + return T{tmax(a.x, b.x), tmax(a.y, b.y), tmax(a.z, b.z), tmax(a.w, b.w)}; +} + +template +static inline __device__ T vmin(const T& a, const T& b) { + return T{tmin(a.x, b.x), tmin(a.y, b.y), tmin(a.z, b.z), tmin(a.w, b.w)}; +} + +template +static inline __device__ T vsum(const T& a, const T& b) { + return T{tsum(a.x, b.x), tsum(a.y, b.y), tsum(a.z, b.z), tsum(a.w, b.w)}; +} + +template +static inline __device__ T vprod(const T& a, const T& b) { + return T{(a.x * b.x), (a.y * b.y), (a.z * b.z), (a.w * b.w)}; +} + + +/* TODO: missing support for + * - short float (conditional on whether short float is available) + * - some Fortran types + * - some complex types + */ + +#define USE_VECTORS 1 + +#define OP_FUNC(name, type_name, type, op) \ + static __global__ void \ + 
ompi_op_cuda_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \
+                                                 type *__restrict__ inout, int n) { \
+    /* grid-stride loop: thread `index` handles index, index+stride, ... */ \
+    const int index = blockIdx.x * blockDim.x + threadIdx.x; \
+    const int stride = blockDim.x * gridDim.x; \
+    for (int i = index; i < n; i += stride) { \
+        inout[i] = inout[i] op in[i]; \
+    } \
+} \
+/* Host-side launcher: clamp geometry to the element count and caller caps. */ \
+void ompi_op_cuda_2buff_##name##_##type_name##_submit(const type *in, \
+                                                      type *inout, \
+                                                      int count, \
+                                                      int threads_per_block, \
+                                                      int max_blocks, \
+                                                      CUstream stream) { \
+    int threads = min(count, threads_per_block); \
+    int blocks = min((count + threads-1) / threads, max_blocks); \
+    int n = count; \
+    CUstream s = stream; \
+    /* NOTE(review): launch configuration restored; source text was garbled to '<<>>'. */ \
+    ompi_op_cuda_2buff_##name##_##type_name##_kernel<<<blocks, threads, 0, s>>>(in, inout, n); \
+}
+
+
+#if defined(USE_VECTORS)
+/*
+ * Vectorized (out-of-place combiner) 2buff reduction: process vlen-wide
+ * packs (e.g. int4) in the strided loop, then one designated thread mops
+ * up the n % vlen tail scalars.
+ */
+#define OPV_FUNC(name, type_name, type, vtype, vlen, op) \
+    static __global__ void \
+    ompi_op_cuda_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \
+                                                     type *__restrict__ inout, int n) { \
+        const int index = blockIdx.x * blockDim.x + threadIdx.x; \
+        const int stride = blockDim.x * gridDim.x; \
+        for (int i = index; i < n/vlen; i += stride) { \
+            vtype vin = ((vtype*)in)[i]; \
+            vtype vinout = ((vtype*)inout)[i]; \
+            vinout.x = vinout.x op vin.x; \
+            vinout.y = vinout.y op vin.y; \
+            vinout.z = vinout.z op vin.z; \
+            vinout.w = vinout.w op vin.w; \
+            ((vtype*)inout)[i] = vinout; \
+        } \
+        int remainder = n%vlen; \
+        /* FIX: compare against (n/vlen) % stride so the tail owner always \
+         * exists; the original 'index == (n/vlen)' test never fired when \
+         * max_blocks capped the grid below n/vlen threads, silently \
+         * dropping the last n % vlen elements. */ \
+        if (remainder != 0 && index == (n/vlen) % stride) { \
+            while(remainder) { \
+                int idx = n - remainder--; \
+                inout[idx] = inout[idx] op in[idx]; \
+            } \
+        } \
+    } \
+    void ompi_op_cuda_2buff_##name##_##type_name##_submit(const type *in, \
+                                                          type *inout, \
+                                                          int count, \
+                                                          int threads_per_block, \
+                                                          int max_blocks, \
+                                                          CUstream stream) { \
+        int vcount = (count + vlen-1)/vlen; \
+        int threads = min(threads_per_block, vcount); \
+        int blocks = min((vcount + threads-1) / threads, max_blocks); \
+        int n = count; \
+        CUstream s = stream; \
+        ompi_op_cuda_2buff_##name##_##type_name##_kernel<<<blocks, threads, 0, s>>>(in, inout, n); \
+    }
+#else // USE_VECTORS
+#define OPV_FUNC(name, type_name, type, vtype, vlen, op) OP_FUNC(name, type_name, type, op)
+#endif // USE_VECTORS
+
+/* 2buff reduction whose combiner is a function-like macro 'fn(a, b)'. */
+#define FUNC_FUNC_FN(name, type_name, type, fn) \
+    static __global__ void \
+    ompi_op_cuda_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \
+                                                     type *__restrict__ inout, int n) { \
+        const int index = blockIdx.x * blockDim.x + threadIdx.x; \
+        const int stride = blockDim.x * gridDim.x; \
+        for (int i = index; i < n; i += stride) { \
+            inout[i] = fn(inout[i], in[i]); \
+        } \
+    } \
+    void \
+    ompi_op_cuda_2buff_##name##_##type_name##_submit(const type *in, \
+                                                     type *inout, \
+                                                     int count, \
+                                                     int threads_per_block, \
+                                                     int max_blocks, \
+                                                     CUstream stream) { \
+        int threads = min(count, threads_per_block); \
+        int blocks = min((count + threads-1) / threads, max_blocks); \
+        int n = count; \
+        CUstream s = stream; \
+        ompi_op_cuda_2buff_##name##_##type_name##_kernel<<<blocks, threads, 0, s>>>(in, inout, n); \
+    }
+
+/* 'current_func' is (re)#defined immediately before each FUNC_FUNC use below. */
+#define FUNC_FUNC(name, type_name, type) FUNC_FUNC_FN(name, type_name, type, current_func)
+
+#if defined(USE_VECTORS)
+/* Vectorized combiner variant: vfn on vlen-packs, scalar fn on the tail. */
+#define VFUNC_FUNC(name, type_name, type, vtype, vlen, vfn, fn) \
+    static __global__ void \
+    ompi_op_cuda_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \
+                                                     type *__restrict__ inout, int n) { \
+        const int index = blockIdx.x * blockDim.x + threadIdx.x; \
+        const int stride = blockDim.x * gridDim.x; \
+        for (int i = index; i < n/vlen; i += stride) { \
+            ((vtype*)inout)[i] = vfn(((vtype*)inout)[i], ((vtype*)in)[i]); \
+        } \
+        int remainder = n%vlen; \
+        /* FIX: see OPV_FUNC -- the modulo keeps the tail owner inside the grid. */ \
+        if (remainder != 0 && index == (n/vlen) % stride) { \
+            while(remainder) { \
+                int idx = n - remainder--; \
+                inout[idx] = fn(inout[idx], in[idx]); \
+            } \
+        } \
+    } \
+    static void \
+    ompi_op_cuda_2buff_##name##_##type_name##_submit(const type *in, \
+                                                     type *inout, \
+                                                     int count, \
+                                                     int threads_per_block, \
+                                                     int max_blocks, \
+                                                     CUstream stream) { \
+        int vcount = (count + vlen-1)/vlen; \
+        int threads = min(threads_per_block, vcount); \
+        int blocks = min((vcount + threads-1) / threads, max_blocks); \
+        int n = count; \
+        CUstream s = stream; \
+        ompi_op_cuda_2buff_##name##_##type_name##_kernel<<<blocks, threads, 0, s>>>(in, inout, n); \
+    }
+#else
+#define VFUNC_FUNC(name, type_name, type, vtype, vlen, vfn, fn) FUNC_FUNC_FN(name, type_name, type, fn)
+#endif // defined(USE_VECTORS)
+
+/*
+ * Since all the functions in this file are essentially identical, we
+ * use a macro to substitute in names and types. The core operation
+ * in all functions that use this macro is the same.
+ *
+ * This macro is for minloc and maxloc
+ */
+#define LOC_FUNC(name, type_name, op) \
+    static __global__ void \
+    ompi_op_cuda_2buff_##name##_##type_name##_kernel(const ompi_op_predefined_##type_name##_t *__restrict__ in, \
+                                                     ompi_op_predefined_##type_name##_t *__restrict__ inout, \
+                                                     int n) \
+    { \
+        const int index = blockIdx.x * blockDim.x + threadIdx.x; \
+        const int stride = blockDim.x * gridDim.x; \
+        for (int i = index; i < n; i += stride) { \
+            const ompi_op_predefined_##type_name##_t *a = &in[i]; \
+            ompi_op_predefined_##type_name##_t *b = &inout[i]; \
+            if (a->v op b->v) { \
+                b->v = a->v; \
+                b->k = a->k; \
+            } else if (a->v == b->v) { \
+                /* tie on the value: keep the smaller index */ \
+                b->k = (b->k < a->k ? b->k : a->k); \
+            } \
+        } \
+    } \
+    void \
+    ompi_op_cuda_2buff_##name##_##type_name##_submit(const ompi_op_predefined_##type_name##_t *a, \
+                                                     ompi_op_predefined_##type_name##_t *b, \
+                                                     int count, \
+                                                     int threads_per_block, \
+                                                     int max_blocks, \
+                                                     CUstream stream) { \
+        int threads = min(count, threads_per_block); \
+        int blocks = min((count + threads-1) / threads, max_blocks); \
+        CUstream s = stream; \
+        ompi_op_cuda_2buff_##name##_##type_name##_kernel<<<blocks, threads, 0, s>>>(a, b, count); \
+    }
+
+/*
+ * Route a fixed-width integer type to the fundamental-type submit of the
+ * same size/signedness (the vectorized kernels exist only for fundamental
+ * types). Relies on type_name and type being the same token here, so
+ * sizeof(type_name) is a valid expression.
+ */
+#define OPV_DISPATCH(name, type_name, type) \
+    void ompi_op_cuda_2buff_##name##_##type_name##_submit(const type *in, \
+                                                          type *inout, \
+                                                          int count, \
+                                                          int threads_per_block, \
+                                                          int max_blocks, \
+                                                          CUstream stream) { \
+        static_assert(sizeof(type_name) <= sizeof(unsigned long long), "Unknown size type"); \
+        if constexpr(!ISSIGNED(type)) { \
+            if constexpr(sizeof(type_name) == sizeof(unsigned char)) { \
+                ompi_op_cuda_2buff_##name##_uchar_submit((const unsigned char*)in, (unsigned char*)inout, count, \
+                                                         threads_per_block, \
+                                                         max_blocks, stream); \
+            } else if constexpr(sizeof(type_name) == sizeof(unsigned short)) { \
+                ompi_op_cuda_2buff_##name##_ushort_submit((const unsigned short*)in, (unsigned short*)inout, count, \
+                                                          threads_per_block, \
+                                                          max_blocks, stream); \
+            } else if constexpr(sizeof(type_name) == sizeof(unsigned int)) { \
+                ompi_op_cuda_2buff_##name##_uint_submit((const unsigned int*)in, (unsigned int*)inout, count, \
+                                                        threads_per_block, \
+                                                        max_blocks, stream); \
+            } else if constexpr(sizeof(type_name) == sizeof(unsigned long)) { \
+                ompi_op_cuda_2buff_##name##_ulong_submit((const unsigned long*)in, (unsigned long*)inout, count, \
+                                                         threads_per_block, \
+                                                         max_blocks, stream); \
+            } else if constexpr(sizeof(type_name) == sizeof(unsigned long long)) { \
+                ompi_op_cuda_2buff_##name##_ulonglong_submit((const unsigned long long*)in, (unsigned long long*)inout, count, \
+                                                             threads_per_block, \
+                                                             max_blocks, stream); \
+            } \
+        } else { \
+            if constexpr(sizeof(type_name) == sizeof(char)) { \
+                ompi_op_cuda_2buff_##name##_char_submit((const char*)in, (char*)inout, count, \
+                                                        threads_per_block, \
+                                                        max_blocks, stream); \
+            } else if constexpr(sizeof(type_name) == sizeof(short)) { \
+                ompi_op_cuda_2buff_##name##_short_submit((const short*)in, (short*)inout, count, \
+                                                         threads_per_block, \
+                                                         max_blocks, stream); \
+            } else if constexpr(sizeof(type_name) == sizeof(int)) { \
+                ompi_op_cuda_2buff_##name##_int_submit((const int*)in, (int*)inout, count, \
+                                                       threads_per_block, \
+                                                       max_blocks, stream); \
+            } else if constexpr(sizeof(type_name) == sizeof(long)) { \
+                ompi_op_cuda_2buff_##name##_long_submit((const long*)in, (long*)inout, count, \
+                                                        threads_per_block, \
+                                                        max_blocks, stream); \
+            } else if constexpr(sizeof(type_name) == sizeof(long long)) { \
+                ompi_op_cuda_2buff_##name##_longlong_submit((const long long*)in, (long long*)inout, count,\
+                                                            threads_per_block, \
+                                                            max_blocks, stream); \
+            } \
+        } \
+    }
+
+/*************************************************************************
+ * Max
+ *************************************************************************/
+
+/* C integer */
+VFUNC_FUNC(max, char, char, char4, 4, vmax, max)
+VFUNC_FUNC(max, uchar, unsigned char, uchar4, 4, vmax, max)
+VFUNC_FUNC(max, short, short, short4, 4, vmax, max)
+VFUNC_FUNC(max, ushort, unsigned short, ushort4, 4, vmax, max)
+VFUNC_FUNC(max, int, int, int4, 4, vmax, max)
+VFUNC_FUNC(max, uint, unsigned int, uint4, 4, vmax, max)
+
+#undef current_func
+#define current_func(a, b) max(a, b)
+FUNC_FUNC(max, long, long)
+FUNC_FUNC(max, ulong, unsigned long)
+FUNC_FUNC(max, longlong, long long)
+FUNC_FUNC(max, ulonglong, unsigned long long)
+
+/* dispatch fixed-size types */
+OPV_DISPATCH(max, int8_t, int8_t)
+OPV_DISPATCH(max, uint8_t, uint8_t)
+OPV_DISPATCH(max, int16_t, int16_t)
+OPV_DISPATCH(max, uint16_t, uint16_t)
+OPV_DISPATCH(max, int32_t, int32_t)
+OPV_DISPATCH(max, uint32_t, uint32_t)
+OPV_DISPATCH(max, int64_t, int64_t)
+OPV_DISPATCH(max, uint64_t, uint64_t)
+
+#undef current_func
+#define current_func(a, b) ((a) > (b) ? (a) : (b))
+FUNC_FUNC(max, long_double, long double)
+
+#if !defined(DO_NOT_USE_INTRINSICS)
+#undef current_func
+#define current_func(a, b) fmaxf(a, b)
+#endif // DO_NOT_USE_INTRINSICS
+FUNC_FUNC(max, float, float)
+
+#if !defined(DO_NOT_USE_INTRINSICS)
+#undef current_func
+#define current_func(a, b) fmax(a, b)
+#endif // DO_NOT_USE_INTRINSICS
+FUNC_FUNC(max, double, double)
+
+// __CUDA_ARCH__ is only defined when compiling device code
+#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
+#undef current_func
+#define current_func(a, b) __hmax2(a, b)
+//VFUNC_FUNC(max, halfx, half, half2, 2, __hmax2, __hmax)
+#endif // __CUDA_ARCH__
+
+/*************************************************************************
+ * Min
+ *************************************************************************/
+
+/* C integer */
+VFUNC_FUNC(min, char, char, char4, 4, vmin, min)
+VFUNC_FUNC(min, uchar, unsigned char, uchar4, 4, vmin, min)
+VFUNC_FUNC(min, short, short, short4, 4, vmin, min)
+VFUNC_FUNC(min, ushort, unsigned short, ushort4, 4, vmin, min)
+VFUNC_FUNC(min, int, int, int4, 4, vmin, min)
+VFUNC_FUNC(min, uint, unsigned int, uint4, 4, vmin, min)
+
+#undef current_func
+#define current_func(a, b) min(a, b)
+FUNC_FUNC(min, long, long)
+FUNC_FUNC(min, ulong, unsigned long)
+FUNC_FUNC(min, longlong, long long)
+FUNC_FUNC(min, ulonglong, unsigned long long)
+OPV_DISPATCH(min, int8_t, int8_t)
+OPV_DISPATCH(min, uint8_t, uint8_t)
+OPV_DISPATCH(min, int16_t, int16_t)
+OPV_DISPATCH(min, uint16_t, uint16_t)
+OPV_DISPATCH(min, int32_t, int32_t)
+OPV_DISPATCH(min, uint32_t, uint32_t)
+OPV_DISPATCH(min, int64_t, int64_t)
+OPV_DISPATCH(min, uint64_t, uint64_t)
+
+#if !defined(DO_NOT_USE_INTRINSICS)
+#undef current_func
+#define current_func(a, b) fminf(a, b)
+#endif // DO_NOT_USE_INTRINSICS
+FUNC_FUNC(min, float, float)
+
+#if !defined(DO_NOT_USE_INTRINSICS)
+#undef current_func
+#define current_func(a, b) fmin(a, b)
+#endif // DO_NOT_USE_INTRINSICS
+FUNC_FUNC(min, double, double)
+
+#undef current_func
+#define current_func(a, b) ((a) < (b) ? (a) : (b))
+FUNC_FUNC(min, long_double, long double)
+
+// __CUDA_ARCH__ is only defined when compiling device code
+#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
+#undef current_func
+#define current_func(a, b) __hmin2(a, b)
+//VFUNC_FUNC(min, half, half, half2, 2, __hmin2, __hmin)
+#endif // __CUDA_ARCH__
+
+/*************************************************************************
+ * Sum
+ *************************************************************************/
+
+/* C integer */
+VFUNC_FUNC(sum, char, char, char4, 4, vsum, tsum)
+VFUNC_FUNC(sum, uchar, unsigned char, uchar4, 4, vsum, tsum)
+VFUNC_FUNC(sum, short, short, short4, 4, vsum, tsum)
+VFUNC_FUNC(sum, ushort, unsigned short, ushort4, 4, vsum, tsum)
+VFUNC_FUNC(sum, int, int, int4, 4, vsum, tsum)
+VFUNC_FUNC(sum, uint, unsigned int, uint4, 4, vsum, tsum)
+
+#undef current_func
+#define current_func(a, b) tsum(a, b)
+FUNC_FUNC(sum, long, long)
+FUNC_FUNC(sum, ulong, unsigned long)
+FUNC_FUNC(sum, longlong, long long)
+FUNC_FUNC(sum, ulonglong, unsigned long long)
+
+OPV_DISPATCH(sum, int8_t, int8_t)
+OPV_DISPATCH(sum, uint8_t, uint8_t)
+OPV_DISPATCH(sum, int16_t, int16_t)
+OPV_DISPATCH(sum, uint16_t, uint16_t)
+OPV_DISPATCH(sum, int32_t, int32_t)
+OPV_DISPATCH(sum, uint32_t, uint32_t)
+OPV_DISPATCH(sum, int64_t, int64_t)
+OPV_DISPATCH(sum, uint64_t, uint64_t)
+
+OPV_FUNC(sum, float, float, float4, 4, +)
+OPV_FUNC(sum, double, double, double4, 4, +)
+OP_FUNC(sum, long_double, long double, +)
+
+// __CUDA_ARCH__ is only defined when compiling device code
+#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
+#undef current_func
+#define current_func(a, b) __hadd2(a, b)
+//VFUNC_FUNC(sum, half, half, half2, 2, __hadd2, __hadd)
+#endif // __CUDA_ARCH__
+
+/* Complex */
+#if 0
+#if defined(HAVE_SHORT_FLOAT__COMPLEX)
+OP_FUNC(sum, c_short_float_complex, short float _Complex, +=)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T)
+COMPLEX_SUM_FUNC(c_short_float_complex, opal_short_float_t)
+OP_FUNC(sum, c_long_double_complex, cuLongDoubleComplex, +=)
+#endif
+#endif // 0
+#undef current_func
+#define current_func(a, b) (cuCaddf(a,b))
+FUNC_FUNC(sum, c_float_complex, cuFloatComplex)
+#undef current_func
+#define current_func(a, b) (cuCadd(a,b))
+FUNC_FUNC(sum, c_double_complex, cuDoubleComplex)
+
+/*************************************************************************
+ * Product
+ *************************************************************************/
+
+/* C integer */
+#undef current_func
+#define current_func(a, b) tprod(a, b)
+FUNC_FUNC(prod, char, char)
+FUNC_FUNC(prod, uchar, unsigned char)
+FUNC_FUNC(prod, short, short)
+FUNC_FUNC(prod, ushort, unsigned short)
+FUNC_FUNC(prod, int, int)
+FUNC_FUNC(prod, uint, unsigned int)
+FUNC_FUNC(prod, long, long)
+FUNC_FUNC(prod, ulong, unsigned long)
+FUNC_FUNC(prod, longlong, long long)
+FUNC_FUNC(prod, ulonglong, unsigned long long)
+
+OPV_DISPATCH(prod, int8_t, int8_t)
+OPV_DISPATCH(prod, uint8_t, uint8_t)
+OPV_DISPATCH(prod, int16_t, int16_t)
+OPV_DISPATCH(prod, uint16_t, uint16_t)
+OPV_DISPATCH(prod, int32_t, int32_t)
+OPV_DISPATCH(prod, uint32_t, uint32_t)
+OPV_DISPATCH(prod, int64_t, int64_t)
+OPV_DISPATCH(prod, uint64_t, uint64_t)
+
+OPV_FUNC(prod, float, float, float4, 4, *)
+OPV_FUNC(prod, double, double, double4, 4, *)
+OP_FUNC(prod, long_double, long double, *)
+
+// __CUDA_ARCH__ is only defined when compiling device code
+#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
+#undef current_func
+#define current_func(a, b) __hmul2(a, b)
+//VFUNC_FUNC(prod, half, half, half2, 2, __hmul2, __hmul)
+#endif // __CUDA_ARCH__
+
+/* Complex */
+#if 0
+#if defined(HAVE_SHORT_FLOAT__COMPLEX)
+OP_FUNC(prod, c_short_float_complex, short float _Complex, *=)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T)
+COMPLEX_PROD_FUNC(c_short_float_complex, opal_short_float_t)
+#endif
+OP_FUNC(prod, c_long_double_complex, long double _Complex, *=)
+#endif // 0
+/* Complex product via cuComplex.h helpers: cuCmulf (float), cuCmul (double). */
+#undef current_func
+#define current_func(a, b) (cuCmulf(a,b))
+FUNC_FUNC(prod, c_float_complex, cuFloatComplex)
+#undef current_func
+#define current_func(a, b) (cuCmul(a,b))
+FUNC_FUNC(prod, c_double_complex, cuDoubleComplex)
+
+/*************************************************************************
+ * Logical AND
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) && (b))
+/* C integer */
+FUNC_FUNC(land, int8_t, int8_t)
+FUNC_FUNC(land, uint8_t, uint8_t)
+FUNC_FUNC(land, int16_t, int16_t)
+FUNC_FUNC(land, uint16_t, uint16_t)
+FUNC_FUNC(land, int32_t, int32_t)
+FUNC_FUNC(land, uint32_t, uint32_t)
+FUNC_FUNC(land, int64_t, int64_t)
+FUNC_FUNC(land, uint64_t, uint64_t)
+FUNC_FUNC(land, long, long)
+FUNC_FUNC(land, ulong, unsigned long)
+
+/* C++ bool */
+FUNC_FUNC(land, bool, bool)
+
+/*************************************************************************
+ * Logical OR
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) || (b))
+/* C integer */
+FUNC_FUNC(lor, int8_t, int8_t)
+FUNC_FUNC(lor, uint8_t, uint8_t)
+FUNC_FUNC(lor, int16_t, int16_t)
+FUNC_FUNC(lor, uint16_t, uint16_t)
+FUNC_FUNC(lor, int32_t, int32_t)
+FUNC_FUNC(lor, uint32_t, uint32_t)
+FUNC_FUNC(lor, int64_t, int64_t)
+FUNC_FUNC(lor, uint64_t, uint64_t)
+FUNC_FUNC(lor, long, long)
+FUNC_FUNC(lor, ulong, unsigned long)
+
+/* C++ bool */
+FUNC_FUNC(lor, bool, bool)
+
+/*************************************************************************
+ * Logical XOR
+ *************************************************************************/
+
+/* Normalize both operands to 0/1 before XOR so any non-zero value counts as true. */
+#undef current_func
+#define current_func(a, b) ((a ? 1 : 0) ^ (b ? 1: 0))
+/* C integer */
+FUNC_FUNC(lxor, int8_t, int8_t)
+FUNC_FUNC(lxor, uint8_t, uint8_t)
+FUNC_FUNC(lxor, int16_t, int16_t)
+FUNC_FUNC(lxor, uint16_t, uint16_t)
+FUNC_FUNC(lxor, int32_t, int32_t)
+FUNC_FUNC(lxor, uint32_t, uint32_t)
+FUNC_FUNC(lxor, int64_t, int64_t)
+FUNC_FUNC(lxor, uint64_t, uint64_t)
+FUNC_FUNC(lxor, long, long)
+FUNC_FUNC(lxor, ulong, unsigned long)
+
+/* C++ bool */
+FUNC_FUNC(lxor, bool, bool)
+
+/*************************************************************************
+ * Bitwise AND
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) & (b))
+/* C integer */
+FUNC_FUNC(band, int8_t, int8_t)
+FUNC_FUNC(band, uint8_t, uint8_t)
+FUNC_FUNC(band, int16_t, int16_t)
+FUNC_FUNC(band, uint16_t, uint16_t)
+FUNC_FUNC(band, int32_t, int32_t)
+FUNC_FUNC(band, uint32_t, uint32_t)
+FUNC_FUNC(band, int64_t, int64_t)
+FUNC_FUNC(band, uint64_t, uint64_t)
+FUNC_FUNC(band, long, long)
+FUNC_FUNC(band, ulong, unsigned long)
+
+/* Byte */
+FUNC_FUNC(band, byte, char)
+
+/*************************************************************************
+ * Bitwise OR
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) | (b))
+/* C integer */
+FUNC_FUNC(bor, int8_t, int8_t)
+FUNC_FUNC(bor, uint8_t, uint8_t)
+FUNC_FUNC(bor, int16_t, int16_t)
+FUNC_FUNC(bor, uint16_t, uint16_t)
+FUNC_FUNC(bor, int32_t, int32_t)
+FUNC_FUNC(bor, uint32_t, uint32_t)
+FUNC_FUNC(bor, int64_t, int64_t)
+FUNC_FUNC(bor, uint64_t, uint64_t)
+FUNC_FUNC(bor, long, long)
+FUNC_FUNC(bor, ulong, unsigned long)
+
+/* Byte */
+FUNC_FUNC(bor, byte, char)
+
+/*************************************************************************
+ * Bitwise XOR
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) ^ (b))
+/* C integer */
+FUNC_FUNC(bxor, int8_t, int8_t)
+FUNC_FUNC(bxor, uint8_t, uint8_t)
+FUNC_FUNC(bxor, int16_t, int16_t)
+FUNC_FUNC(bxor, uint16_t, uint16_t)
+FUNC_FUNC(bxor, int32_t, int32_t)
+FUNC_FUNC(bxor, uint32_t, uint32_t)
+FUNC_FUNC(bxor, int64_t, int64_t)
+FUNC_FUNC(bxor, uint64_t, uint64_t)
+FUNC_FUNC(bxor, long, long)
+FUNC_FUNC(bxor, ulong, unsigned long)
+
+/* Byte */
+FUNC_FUNC(bxor, byte, char)
+
+/*************************************************************************
+ * Max location
+ *************************************************************************/
+
+/* LOC_FUNC instantiates (value, index) pair reductions; ties keep the lower index. */
+LOC_FUNC(maxloc, float_int, >)
+LOC_FUNC(maxloc, double_int, >)
+LOC_FUNC(maxloc, long_int, >)
+LOC_FUNC(maxloc, 2int, >)
+LOC_FUNC(maxloc, short_int, >)
+LOC_FUNC(maxloc, long_double_int, >)
+
+/*************************************************************************
+ * Min location
+ *************************************************************************/
+
+LOC_FUNC(minloc, float_int, <)
+LOC_FUNC(minloc, double_int, <)
+LOC_FUNC(minloc, long_int, <)
+LOC_FUNC(minloc, 2int, <)
+LOC_FUNC(minloc, short_int, <)
+LOC_FUNC(minloc, long_double_int, <)
+
+
+/*
+ * This is a three buffer (2 input and 1 output) version of the reduction
+ * routines, needed for some optimizations.
+ */
+/* Element-wise out[i] = in1[i] op in2[i], where 'op' is an infix operator token. */
+#define OP_FUNC_3BUF(name, type_name, type, op) \
+    static __global__ void \
+    ompi_op_cuda_3buff_##name##_##type_name##_kernel(const type *__restrict__ in1, \
+                                                     const type *__restrict__ in2, \
+                                                     type *__restrict__ out, int n) { \
+        const int index = blockIdx.x * blockDim.x + threadIdx.x; \
+        const int stride = blockDim.x * gridDim.x; \
+        for (int i = index; i < n; i += stride) { \
+            out[i] = in1[i] op in2[i]; \
+        } \
+    } \
+    void ompi_op_cuda_3buff_##name##_##type_name##_submit(const type *in1, const type *in2, \
+                                                          type *out, int count, \
+                                                          int threads_per_block, \
+                                                          int max_blocks, \
+                                                          CUstream stream) { \
+        int threads = min(count, threads_per_block); \
+        int blocks = min((count + threads-1) / threads, max_blocks); \
+        /* NOTE(review): launch configuration restored; source text was garbled to '<<>>'. */ \
+        ompi_op_cuda_3buff_##name##_##type_name##_kernel<<<blocks, threads, 0, stream>>>(in1, in2, out, count); \
+    }
+
+
+/*
+ * Since all the functions in this file are essentially identical, we
+ * use a macro to substitute in names and types. The core operation
+ * in all functions that use this macro is the same.
+ *
+ * This macro is for (out = op(in1, in2))
+ */
+#define FUNC_FUNC_3BUF(name, type_name, type) \
+    static __global__ void \
+    ompi_op_cuda_3buff_##name##_##type_name##_kernel(const type *__restrict__ in1, \
+                                                     const type *__restrict__ in2, \
+                                                     type *__restrict__ out, int n) { \
+        const int index = blockIdx.x * blockDim.x + threadIdx.x; \
+        const int stride = blockDim.x * gridDim.x; \
+        for (int i = index; i < n; i += stride) { \
+            out[i] = current_func(in1[i], in2[i]); \
+        } \
+    } \
+    void \
+    ompi_op_cuda_3buff_##name##_##type_name##_submit(const type *in1, const type *in2, \
+                                                     type *out, int count, \
+                                                     int threads_per_block, \
+                                                     int max_blocks, \
+                                                     CUstream stream) { \
+        int threads = min(count, threads_per_block); \
+        int blocks = min((count + threads-1) / threads, max_blocks); \
+        /* NOTE(review): launch configuration restored; source text was garbled to '<<>>'. */ \
+        ompi_op_cuda_3buff_##name##_##type_name##_kernel<<<blocks, threads, 0, stream>>>(in1, in2, out, count); \
+    }
+
+/*
+ * Since all the functions in this file are essentially identical, we
+ * use a macro to substitute in names and types. The core operation
+ * in all functions that use this macro is the same.
+ *
+ * This macro is for minloc and maxloc
+ */
+#define LOC_FUNC_3BUF(name, type_name, op) \
+    static __global__ void \
+    ompi_op_cuda_3buff_##name##_##type_name##_kernel(const ompi_op_predefined_##type_name##_t *__restrict__ in1, \
+                                                     const ompi_op_predefined_##type_name##_t *__restrict__ in2, \
+                                                     ompi_op_predefined_##type_name##_t *__restrict__ out, \
+                                                     int n) \
+    { \
+        const int index = blockIdx.x * blockDim.x + threadIdx.x; \
+        const int stride = blockDim.x * gridDim.x; \
+        for (int i = index; i < n; i += stride) { \
+            const ompi_op_predefined_##type_name##_t *a1 = &in1[i]; \
+            const ompi_op_predefined_##type_name##_t *a2 = &in2[i]; \
+            ompi_op_predefined_##type_name##_t *b = &out[i]; \
+            if (a1->v op a2->v) { \
+                b->v = a1->v; \
+                b->k = a1->k; \
+            } else if (a1->v == a2->v) { \
+                /* tie on the value: keep the smaller index */ \
+                b->v = a1->v; \
+                b->k = (a2->k < a1->k ? a2->k : a1->k); \
+            } else { \
+                b->v = a2->v; \
+                b->k = a2->k; \
+            } \
+        } \
+    } \
+    void \
+    ompi_op_cuda_3buff_##name##_##type_name##_submit(const ompi_op_predefined_##type_name##_t *in1, \
+                                                     const ompi_op_predefined_##type_name##_t *in2, \
+                                                     ompi_op_predefined_##type_name##_t *out, \
+                                                     int count, \
+                                                     int threads_per_block, \
+                                                     int max_blocks, \
+                                                     CUstream stream) \
+    { \
+        int threads = min(count, threads_per_block); \
+        int blocks = min((count + threads-1) / threads, max_blocks); \
+        /* NOTE(review): launch configuration restored; source text was garbled to '<<>>'. */ \
+        ompi_op_cuda_3buff_##name##_##type_name##_kernel<<<blocks, threads, 0, stream>>>(in1, in2, out, count); \
+    }
+
+
+/*************************************************************************
+ * Max
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) > (b) ? (a) : (b))
+/* C integer */
+FUNC_FUNC_3BUF(max, int8_t, int8_t)
+FUNC_FUNC_3BUF(max, uint8_t, uint8_t)
+FUNC_FUNC_3BUF(max, int16_t, int16_t)
+FUNC_FUNC_3BUF(max, uint16_t, uint16_t)
+FUNC_FUNC_3BUF(max, int32_t, int32_t)
+FUNC_FUNC_3BUF(max, uint32_t, uint32_t)
+FUNC_FUNC_3BUF(max, int64_t, int64_t)
+FUNC_FUNC_3BUF(max, uint64_t, uint64_t)
+FUNC_FUNC_3BUF(max, long, long)
+FUNC_FUNC_3BUF(max, ulong, unsigned long)
+
+/* Floating point */
+#if defined(HAVE_SHORT_FLOAT)
+FUNC_FUNC_3BUF(max, short_float, short float)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_T)
+FUNC_FUNC_3BUF(max, short_float, opal_short_float_t)
+#endif
+FUNC_FUNC_3BUF(max, float, float)
+FUNC_FUNC_3BUF(max, double, double)
+FUNC_FUNC_3BUF(max, long_double, long double)
+
+
+/*************************************************************************
+ * Min
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) < (b) ? (a) : (b))
+/* C integer */
+FUNC_FUNC_3BUF(min, int8_t, int8_t)
+FUNC_FUNC_3BUF(min, uint8_t, uint8_t)
+FUNC_FUNC_3BUF(min, int16_t, int16_t)
+FUNC_FUNC_3BUF(min, uint16_t, uint16_t)
+FUNC_FUNC_3BUF(min, int32_t, int32_t)
+FUNC_FUNC_3BUF(min, uint32_t, uint32_t)
+FUNC_FUNC_3BUF(min, int64_t, int64_t)
+FUNC_FUNC_3BUF(min, uint64_t, uint64_t)
+FUNC_FUNC_3BUF(min, long, long)
+FUNC_FUNC_3BUF(min, ulong, unsigned long)
+
+/* Floating point */
+#if defined(HAVE_SHORT_FLOAT)
+FUNC_FUNC_3BUF(min, short_float, short float)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_T)
+FUNC_FUNC_3BUF(min, short_float, opal_short_float_t)
+#endif
+FUNC_FUNC_3BUF(min, float, float)
+FUNC_FUNC_3BUF(min, double, double)
+FUNC_FUNC_3BUF(min, long_double, long double)
+
+/*************************************************************************
+ * Sum
+ *************************************************************************/
+
+/* C integer */
+OP_FUNC_3BUF(sum, int8_t, int8_t, +)
+OP_FUNC_3BUF(sum, uint8_t, uint8_t, +)
+OP_FUNC_3BUF(sum, int16_t, int16_t, +)
+OP_FUNC_3BUF(sum, uint16_t, uint16_t, +)
+OP_FUNC_3BUF(sum, int32_t, int32_t, +)
+OP_FUNC_3BUF(sum, uint32_t, uint32_t, +)
+OP_FUNC_3BUF(sum, int64_t, int64_t, +)
+OP_FUNC_3BUF(sum, uint64_t, uint64_t, +)
+OP_FUNC_3BUF(sum, long, long, +)
+OP_FUNC_3BUF(sum, ulong, unsigned long, +)
+
+/* Floating point */
+#if defined(HAVE_SHORT_FLOAT)
+OP_FUNC_3BUF(sum, short_float, short float, +)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_T)
+OP_FUNC_3BUF(sum, short_float, opal_short_float_t, +)
+#endif
+OP_FUNC_3BUF(sum, float, float, +)
+OP_FUNC_3BUF(sum, double, double, +)
+OP_FUNC_3BUF(sum, long_double, long double, +)
+
+/* Complex */
+#if 0
+#if defined(HAVE_SHORT_FLOAT__COMPLEX)
+OP_FUNC_3BUF(sum, c_short_float_complex, short float _Complex, +)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T)
+COMPLEX_SUM_FUNC_3BUF(c_short_float_complex, opal_short_float_t)
+#endif
+OP_FUNC_3BUF(sum, c_long_double_complex, cuLongDoubleComplex, +)
+#endif // 0
+#undef current_func
+#define current_func(a, b) (cuCaddf(a,b))
+FUNC_FUNC_3BUF(sum, c_float_complex, cuFloatComplex)
+#undef current_func
+#define current_func(a, b) (cuCadd(a,b))
+FUNC_FUNC_3BUF(sum, c_double_complex, cuDoubleComplex)
+
+/*************************************************************************
+ * Product
+ *************************************************************************/
+
+/* C integer */
+OP_FUNC_3BUF(prod, int8_t, int8_t, *)
+OP_FUNC_3BUF(prod, uint8_t, uint8_t, *)
+OP_FUNC_3BUF(prod, int16_t, int16_t, *)
+OP_FUNC_3BUF(prod, uint16_t, uint16_t, *)
+OP_FUNC_3BUF(prod, int32_t, int32_t, *)
+OP_FUNC_3BUF(prod, uint32_t, uint32_t, *)
+OP_FUNC_3BUF(prod, int64_t, int64_t, *)
+OP_FUNC_3BUF(prod, uint64_t, uint64_t, *)
+OP_FUNC_3BUF(prod, long, long, *)
+OP_FUNC_3BUF(prod, ulong, unsigned long, *)
+
+/* Floating point */
+#if defined(HAVE_SHORT_FLOAT)
+OP_FUNC_3BUF(prod, short_float, short float, *)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_T)
+OP_FUNC_3BUF(prod, short_float, opal_short_float_t, *)
+#endif
+OP_FUNC_3BUF(prod, float, float, *)
+OP_FUNC_3BUF(prod, double, double, *)
+OP_FUNC_3BUF(prod, long_double, long double, *)
+
+/* Complex */
+#if 0
+#if defined(HAVE_SHORT_FLOAT__COMPLEX)
+OP_FUNC_3BUF(prod, c_short_float_complex, short float _Complex, *)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T)
+COMPLEX_PROD_FUNC_3BUF(c_short_float_complex, opal_short_float_t)
+#endif
+OP_FUNC_3BUF(prod, c_long_double_complex, long double _Complex, *)
+#endif // 0
+#undef current_func
+#define current_func(a, b) (cuCmulf(a,b))
+FUNC_FUNC_3BUF(prod, c_float_complex, cuFloatComplex)
+#undef current_func
+#define current_func(a, b) (cuCmul(a,b))
+FUNC_FUNC_3BUF(prod, c_double_complex, cuDoubleComplex)
+
+/*************************************************************************
+ * Logical AND
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) && (b))
+/* C integer */
+FUNC_FUNC_3BUF(land, int8_t, int8_t)
+FUNC_FUNC_3BUF(land, uint8_t, uint8_t)
+FUNC_FUNC_3BUF(land, int16_t, int16_t)
+FUNC_FUNC_3BUF(land, uint16_t, uint16_t)
+FUNC_FUNC_3BUF(land, int32_t, int32_t)
+FUNC_FUNC_3BUF(land, uint32_t, uint32_t)
+FUNC_FUNC_3BUF(land, int64_t, int64_t)
+FUNC_FUNC_3BUF(land, uint64_t, uint64_t)
+FUNC_FUNC_3BUF(land, long, long)
+FUNC_FUNC_3BUF(land, ulong, unsigned long)
+
+/* C++ bool */
+FUNC_FUNC_3BUF(land, bool, bool)
+
+/*************************************************************************
+ * Logical OR
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) || (b))
+/* C integer */
+FUNC_FUNC_3BUF(lor, int8_t, int8_t)
+FUNC_FUNC_3BUF(lor, uint8_t, uint8_t)
+FUNC_FUNC_3BUF(lor, int16_t, int16_t)
+FUNC_FUNC_3BUF(lor, uint16_t, uint16_t)
+FUNC_FUNC_3BUF(lor, int32_t, int32_t)
+FUNC_FUNC_3BUF(lor, uint32_t, uint32_t)
+FUNC_FUNC_3BUF(lor, int64_t, int64_t)
+FUNC_FUNC_3BUF(lor, uint64_t, uint64_t)
+FUNC_FUNC_3BUF(lor, long, long)
+FUNC_FUNC_3BUF(lor, ulong, unsigned long)
+
+/* C++ bool */
+FUNC_FUNC_3BUF(lor, bool, bool)
+
+/*************************************************************************
+ * Logical XOR
+ *************************************************************************/
+
+/* Normalize both operands to 0/1 before XOR so any non-zero value counts as true. */
+#undef current_func
+#define current_func(a, b) ((a ? 1 : 0) ^ (b ? 1: 0))
+/* C integer */
+FUNC_FUNC_3BUF(lxor, int8_t, int8_t)
+FUNC_FUNC_3BUF(lxor, uint8_t, uint8_t)
+FUNC_FUNC_3BUF(lxor, int16_t, int16_t)
+FUNC_FUNC_3BUF(lxor, uint16_t, uint16_t)
+FUNC_FUNC_3BUF(lxor, int32_t, int32_t)
+FUNC_FUNC_3BUF(lxor, uint32_t, uint32_t)
+FUNC_FUNC_3BUF(lxor, int64_t, int64_t)
+FUNC_FUNC_3BUF(lxor, uint64_t, uint64_t)
+FUNC_FUNC_3BUF(lxor, long, long)
+FUNC_FUNC_3BUF(lxor, ulong, unsigned long)
+
+/* C++ bool */
+FUNC_FUNC_3BUF(lxor, bool, bool)
+
+/*************************************************************************
+ * Bitwise AND
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) & (b))
+/* C integer */
+FUNC_FUNC_3BUF(band, int8_t, int8_t)
+FUNC_FUNC_3BUF(band, uint8_t, uint8_t)
+FUNC_FUNC_3BUF(band, int16_t, int16_t)
+FUNC_FUNC_3BUF(band, uint16_t, uint16_t)
+FUNC_FUNC_3BUF(band, int32_t, int32_t)
+FUNC_FUNC_3BUF(band, uint32_t, uint32_t)
+FUNC_FUNC_3BUF(band, int64_t, int64_t)
+FUNC_FUNC_3BUF(band, uint64_t, uint64_t)
+FUNC_FUNC_3BUF(band, long, long)
+FUNC_FUNC_3BUF(band, ulong, unsigned long)
+
+/* Byte */
+FUNC_FUNC_3BUF(band, byte, char)
+
+/*************************************************************************
+ * Bitwise OR
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) | (b))
+/* C integer */
+FUNC_FUNC_3BUF(bor, int8_t, int8_t)
+FUNC_FUNC_3BUF(bor, uint8_t, uint8_t)
+FUNC_FUNC_3BUF(bor, int16_t, int16_t)
+FUNC_FUNC_3BUF(bor, uint16_t, uint16_t)
+FUNC_FUNC_3BUF(bor, int32_t, int32_t)
+FUNC_FUNC_3BUF(bor, uint32_t, uint32_t)
+FUNC_FUNC_3BUF(bor, int64_t, int64_t)
+FUNC_FUNC_3BUF(bor, uint64_t, uint64_t)
+FUNC_FUNC_3BUF(bor, long, long)
+FUNC_FUNC_3BUF(bor, ulong, unsigned long)
+
+/* Byte */
+FUNC_FUNC_3BUF(bor, byte, char)
+
+/*************************************************************************
+ * Bitwise XOR
*************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) ^ (b))
+/* C integer */
+FUNC_FUNC_3BUF(bxor, int8_t, int8_t)
+FUNC_FUNC_3BUF(bxor, uint8_t, uint8_t)
+FUNC_FUNC_3BUF(bxor, int16_t, int16_t)
+FUNC_FUNC_3BUF(bxor, uint16_t, uint16_t)
+FUNC_FUNC_3BUF(bxor, int32_t, int32_t)
+FUNC_FUNC_3BUF(bxor, uint32_t, uint32_t)
+FUNC_FUNC_3BUF(bxor, int64_t, int64_t)
+FUNC_FUNC_3BUF(bxor, uint64_t, uint64_t)
+FUNC_FUNC_3BUF(bxor, long, long)
+FUNC_FUNC_3BUF(bxor, ulong, unsigned long)
+
+/* Byte */
+FUNC_FUNC_3BUF(bxor, byte, char)
+
+/*************************************************************************
+ * Max location
+ *************************************************************************/
+
+/* LOC_FUNC_3BUF instantiates (value, index) pair reductions; ties keep the lower index. */
+LOC_FUNC_3BUF(maxloc, float_int, >)
+LOC_FUNC_3BUF(maxloc, double_int, >)
+LOC_FUNC_3BUF(maxloc, long_int, >)
+LOC_FUNC_3BUF(maxloc, 2int, >)
+LOC_FUNC_3BUF(maxloc, short_int, >)
+LOC_FUNC_3BUF(maxloc, long_double_int, >)
+
+/*************************************************************************
+ * Min location
+ *************************************************************************/
+
+LOC_FUNC_3BUF(minloc, float_int, <)
+LOC_FUNC_3BUF(minloc, double_int, <)
+LOC_FUNC_3BUF(minloc, long_int, <)
+LOC_FUNC_3BUF(minloc, 2int, <)
+LOC_FUNC_3BUF(minloc, short_int, <)
+LOC_FUNC_3BUF(minloc, long_double_int, <)
diff --git a/ompi/mca/op/cuda/op_cuda_impl.h b/ompi/mca/op/cuda/op_cuda_impl.h
new file mode 100644
index 00000000000..2c02c32c313
--- /dev/null
+++ b/ompi/mca/op/cuda/op_cuda_impl.h
@@ -0,0 +1,674 @@
+/*
+ * Copyright (c) 2019-2023 The University of Tennessee and The University
+ *                         of Tennessee Research Foundation. All rights
+ *                         reserved.
+ * Copyright (c) 2020      Research Organization for Information Science
+ *                         and Technology (RIST). All rights reserved.
+ * $COPYRIGHT$
+ *
+ * Additional copyrights may follow
+ *
+ * $HEADER$
+ */
+
+/* NOTE(review): the #include targets below were lost in extraction
+ * (angle-bracket contents stripped); presumably <cuda.h>, <cuComplex.h>
+ * and fixed-width integer headers -- restore from upstream. TODO confirm */
+#include
+
+#include
+#include
+#include
+
+#ifndef BEGIN_C_DECLS
+#if defined(c_plusplus) || defined(__cplusplus)
+#    define BEGIN_C_DECLS extern "C" {
+#    define END_C_DECLS }
+#else
+#    define BEGIN_C_DECLS /* empty */
+#    define END_C_DECLS /* empty */
+#endif
+#endif
+
+BEGIN_C_DECLS
+
+/* Prototype for a 2buff submit generated by OP_FUNC (infix-operator kernels). */
+#define OP_FUNC_SIG(name, type_name, type, op) \
+    void ompi_op_cuda_2buff_##name##_##type_name##_submit(const type *in, \
+                                                          type *inout, \
+                                                          int count, \
+                                                          int threads_per_block, \
+                                                          int max_blocks, \
+                                                          CUstream stream);
+
+/* Prototype for a 2buff submit generated by FUNC_FUNC (combiner-macro kernels). */
+#define FUNC_FUNC_SIG(name, type_name, type) \
+    void ompi_op_cuda_2buff_##name##_##type_name##_submit(const type *in, \
+                                                          type *inout, \
+                                                          int count, \
+                                                          int threads_per_block, \
+                                                          int max_blocks, \
+                                                          CUstream stream);
+
+/*
+ * Since all the functions in this file are essentially identical, we
+ * use a macro to substitute in names and types. The core operation
+ * in all functions that use this macro is the same.
+ *
+ * This macro is for minloc and maxloc
+ */
+/* (value, index) pair layout shared by the minloc/maxloc kernels. */
+#define LOC_STRUCT(type_name, type1, type2) \
+    typedef struct { \
+        type1 v; \
+        type2 k; \
+    } ompi_op_predefined_##type_name##_t;
+
+#define LOC_FUNC_SIG(name, type_name, op) \
+    void ompi_op_cuda_2buff_##name##_##type_name##_submit(const ompi_op_predefined_##type_name##_t *a, \
+                                                          ompi_op_predefined_##type_name##_t *b, \
+                                                          int count, \
+                                                          int threads_per_block, \
+                                                          int max_blocks, \
+                                                          CUstream stream);
+
+/*************************************************************************
+ * Max
+ *************************************************************************/
+
+/* C integer */
+FUNC_FUNC_SIG(max, int8_t, int8_t)
+FUNC_FUNC_SIG(max, uint8_t, uint8_t)
+FUNC_FUNC_SIG(max, int16_t, int16_t)
+FUNC_FUNC_SIG(max, uint16_t, uint16_t)
+FUNC_FUNC_SIG(max, int32_t, int32_t)
+FUNC_FUNC_SIG(max, uint32_t, uint32_t)
+FUNC_FUNC_SIG(max, int64_t, int64_t)
+FUNC_FUNC_SIG(max, uint64_t, uint64_t)
+FUNC_FUNC_SIG(max, long, long)
+FUNC_FUNC_SIG(max, ulong, unsigned long)
+
+#if 0
+/* Floating point */
+#if defined(HAVE_SHORT_FLOAT)
+FUNC_FUNC_SIG(max, short_float, short float)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_T)
+FUNC_FUNC_SIG(max, short_float, opal_short_float_t)
+#endif
+#endif // 0
+
+FUNC_FUNC_SIG(max, float, float)
+FUNC_FUNC_SIG(max, double, double)
+FUNC_FUNC_SIG(max, long_double, long double)
+
+/*************************************************************************
+ * Min
+ *************************************************************************/
+
+/* C integer */
+FUNC_FUNC_SIG(min, int8_t, int8_t)
+FUNC_FUNC_SIG(min, uint8_t, uint8_t)
+FUNC_FUNC_SIG(min, int16_t, int16_t)
+FUNC_FUNC_SIG(min, uint16_t, uint16_t)
+FUNC_FUNC_SIG(min, int32_t, int32_t)
+FUNC_FUNC_SIG(min, uint32_t, uint32_t)
+FUNC_FUNC_SIG(min, int64_t, int64_t)
+FUNC_FUNC_SIG(min, uint64_t, uint64_t)
+FUNC_FUNC_SIG(min, long, long)
+FUNC_FUNC_SIG(min, ulong, unsigned long)
+
+#if 0
+/* Floating point */
+#if defined(HAVE_SHORT_FLOAT)
+FUNC_FUNC_SIG(min, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_SIG(min, short_float, opal_short_float_t) +#endif +#endif // 0 + +FUNC_FUNC_SIG(min, float, float) +FUNC_FUNC_SIG(min, double, double) +FUNC_FUNC_SIG(min, long_double, long double) + +/************************************************************************* + * Sum + *************************************************************************/ + +/* C integer */ +OP_FUNC_SIG(sum, int8_t, int8_t, +=) +OP_FUNC_SIG(sum, uint8_t, uint8_t, +=) +OP_FUNC_SIG(sum, int16_t, int16_t, +=) +OP_FUNC_SIG(sum, uint16_t, uint16_t, +=) +OP_FUNC_SIG(sum, int32_t, int32_t, +=) +OP_FUNC_SIG(sum, uint32_t, uint32_t, +=) +OP_FUNC_SIG(sum, int64_t, int64_t, +=) +OP_FUNC_SIG(sum, uint64_t, uint64_t, +=) +OP_FUNC_SIG(sum, long, long, +=) +OP_FUNC_SIG(sum, ulong, unsigned long, +=) + +#if __CUDA_ARCH__ >= 530 +OP_FUNC_SIG(sum, half, half, +=) +#endif // __CUDA_ARCH__ + +#if 0 +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_SIG(sum, short_float, short float, +=) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC_SIG(sum, short_float, opal_short_float_t, +=) +#endif +#endif // 0 + +OP_FUNC_SIG(sum, float, float, +=) +OP_FUNC_SIG(sum, double, double, +=) +OP_FUNC_SIG(sum, long_double, long double, +=) + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_SIG(sum, c_short_float_complex, short float _Complex, +=) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC(c_short_float_complex, opal_short_float_t) +OP_FUNC_SIG(sum, c_long_double_complex, long double _Complex, +=) +#endif +#endif // 0 +FUNC_FUNC_SIG(sum, c_float_complex, cuFloatComplex) +FUNC_FUNC_SIG(sum, c_double_complex, cuDoubleComplex) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +OP_FUNC_SIG(prod, int8_t, int8_t, *=) +OP_FUNC_SIG(prod, uint8_t, uint8_t, 
*=)
+OP_FUNC_SIG(prod, int16_t, int16_t, *=)
+OP_FUNC_SIG(prod, uint16_t, uint16_t, *=)
+OP_FUNC_SIG(prod, int32_t, int32_t, *=)
+OP_FUNC_SIG(prod, uint32_t, uint32_t, *=)
+OP_FUNC_SIG(prod, int64_t, int64_t, *=)
+OP_FUNC_SIG(prod, uint64_t, uint64_t, *=)
+OP_FUNC_SIG(prod, long, long, *=)
+OP_FUNC_SIG(prod, ulong, unsigned long, *=)
+
+#if 0
+/* Floating point */
+#if defined(HAVE_SHORT_FLOAT)
+OP_FUNC_SIG(prod, short_float, short float, *=)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_T)
+OP_FUNC_SIG(prod, short_float, opal_short_float_t, *=)
+#endif
+#endif // 0
+
+OP_FUNC_SIG(prod, float, float, *=)
+OP_FUNC_SIG(prod, double, double, *=)
+OP_FUNC_SIG(prod, long_double, long double, *=)
+
+/* Complex */
+#if 0
+#if defined(HAVE_SHORT_FLOAT__COMPLEX)
+OP_FUNC_SIG(prod, c_short_float_complex, short float _Complex, *=)
+#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T)
+COMPLEX_PROD_FUNC(c_short_float_complex, opal_short_float_t)
+#endif
+OP_FUNC_SIG(prod, c_long_double_complex, long double _Complex, *=)
+#endif // 0
+
+FUNC_FUNC_SIG(prod, c_float_complex, cuFloatComplex)
+FUNC_FUNC_SIG(prod, c_double_complex, cuDoubleComplex)
+
+/*************************************************************************
+ * Logical AND
+ *************************************************************************/
+
+#undef current_func
+#define current_func(a, b) ((a) && (b))
+/* C integer */
+FUNC_FUNC_SIG(land, int8_t, int8_t)
+FUNC_FUNC_SIG(land, uint8_t, uint8_t)
+FUNC_FUNC_SIG(land, int16_t, int16_t)
+FUNC_FUNC_SIG(land, uint16_t, uint16_t)
+FUNC_FUNC_SIG(land, int32_t, int32_t)
+FUNC_FUNC_SIG(land, uint32_t, uint32_t)
+FUNC_FUNC_SIG(land, int64_t, int64_t)
+FUNC_FUNC_SIG(land, uint64_t, uint64_t)
+FUNC_FUNC_SIG(land, long, long)
+FUNC_FUNC_SIG(land, ulong, unsigned long)
+
+/* C++ bool */
+FUNC_FUNC_SIG(land, bool, bool)
+
+/*************************************************************************
+ * Logical OR
+ 
*************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) || (b)) +/* C integer */ +FUNC_FUNC_SIG(lor, int8_t, int8_t) +FUNC_FUNC_SIG(lor, uint8_t, uint8_t) +FUNC_FUNC_SIG(lor, int16_t, int16_t) +FUNC_FUNC_SIG(lor, uint16_t, uint16_t) +FUNC_FUNC_SIG(lor, int32_t, int32_t) +FUNC_FUNC_SIG(lor, uint32_t, uint32_t) +FUNC_FUNC_SIG(lor, int64_t, int64_t) +FUNC_FUNC_SIG(lor, uint64_t, uint64_t) +FUNC_FUNC_SIG(lor, long, long) +FUNC_FUNC_SIG(lor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_SIG(lor, bool, bool) + +/************************************************************************* + * Logical XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a ? 1 : 0) ^ (b ? 1: 0)) +/* C integer */ +FUNC_FUNC_SIG(lxor, int8_t, int8_t) +FUNC_FUNC_SIG(lxor, uint8_t, uint8_t) +FUNC_FUNC_SIG(lxor, int16_t, int16_t) +FUNC_FUNC_SIG(lxor, uint16_t, uint16_t) +FUNC_FUNC_SIG(lxor, int32_t, int32_t) +FUNC_FUNC_SIG(lxor, uint32_t, uint32_t) +FUNC_FUNC_SIG(lxor, int64_t, int64_t) +FUNC_FUNC_SIG(lxor, uint64_t, uint64_t) +FUNC_FUNC_SIG(lxor, long, long) +FUNC_FUNC_SIG(lxor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_SIG(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) & (b)) +/* C integer */ +FUNC_FUNC_SIG(band, int8_t, int8_t) +FUNC_FUNC_SIG(band, uint8_t, uint8_t) +FUNC_FUNC_SIG(band, int16_t, int16_t) +FUNC_FUNC_SIG(band, uint16_t, uint16_t) +FUNC_FUNC_SIG(band, int32_t, int32_t) +FUNC_FUNC_SIG(band, uint32_t, uint32_t) +FUNC_FUNC_SIG(band, int64_t, int64_t) +FUNC_FUNC_SIG(band, uint64_t, uint64_t) +FUNC_FUNC_SIG(band, long, long) +FUNC_FUNC_SIG(band, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_SIG(band, byte, char) + 
+/************************************************************************* + * Bitwise OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) | (b)) +/* C integer */ +FUNC_FUNC_SIG(bor, int8_t, int8_t) +FUNC_FUNC_SIG(bor, uint8_t, uint8_t) +FUNC_FUNC_SIG(bor, int16_t, int16_t) +FUNC_FUNC_SIG(bor, uint16_t, uint16_t) +FUNC_FUNC_SIG(bor, int32_t, int32_t) +FUNC_FUNC_SIG(bor, uint32_t, uint32_t) +FUNC_FUNC_SIG(bor, int64_t, int64_t) +FUNC_FUNC_SIG(bor, uint64_t, uint64_t) +FUNC_FUNC_SIG(bor, long, long) +FUNC_FUNC_SIG(bor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_SIG(bor, byte, char) + +/************************************************************************* + * Bitwise XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) ^ (b)) +/* C integer */ +FUNC_FUNC_SIG(bxor, int8_t, int8_t) +FUNC_FUNC_SIG(bxor, uint8_t, uint8_t) +FUNC_FUNC_SIG(bxor, int16_t, int16_t) +FUNC_FUNC_SIG(bxor, uint16_t, uint16_t) +FUNC_FUNC_SIG(bxor, int32_t, int32_t) +FUNC_FUNC_SIG(bxor, uint32_t, uint32_t) +FUNC_FUNC_SIG(bxor, int64_t, int64_t) +FUNC_FUNC_SIG(bxor, uint64_t, uint64_t) +FUNC_FUNC_SIG(bxor, long, long) +FUNC_FUNC_SIG(bxor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_SIG(bxor, byte, char) + +/************************************************************************* + * Min and max location "pair" datatypes + *************************************************************************/ + +LOC_STRUCT(float_int, float, int) +LOC_STRUCT(double_int, double, int) +LOC_STRUCT(long_int, long, int) +LOC_STRUCT(2int, int, int) +LOC_STRUCT(short_int, short, int) +LOC_STRUCT(long_double_int, long double, int) +LOC_STRUCT(ulong, unsigned long, int) +/* compat types for Fortran */ +LOC_STRUCT(2real, float, float) +LOC_STRUCT(2double_precision, double, double) + 
+/************************************************************************* + * Max location + *************************************************************************/ + +LOC_FUNC_SIG(maxloc, float_int, >) +LOC_FUNC_SIG(maxloc, double_int, >) +LOC_FUNC_SIG(maxloc, long_int, >) +LOC_FUNC_SIG(maxloc, 2int, >) +LOC_FUNC_SIG(maxloc, short_int, >) +LOC_FUNC_SIG(maxloc, long_double_int, >) + +/************************************************************************* + * Min location + *************************************************************************/ + +LOC_FUNC_SIG(minloc, float_int, <) +LOC_FUNC_SIG(minloc, double_int, <) +LOC_FUNC_SIG(minloc, long_int, <) +LOC_FUNC_SIG(minloc, 2int, <) +LOC_FUNC_SIG(minloc, short_int, <) +LOC_FUNC_SIG(minloc, long_double_int, <) + + + +#define OP_FUNC_3BUF_SIG(name, type_name, type, op) \ + void ompi_op_cuda_3buff_##name##_##type_name##_submit(const type *in1, \ + const type *in2, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + CUstream stream); + +#define FUNC_FUNC_3BUF_SIG(name, type_name, type) \ + void ompi_op_cuda_3buff_##name##_##type_name##_submit(const type *in1, \ + const type *in2, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + CUstream stream); + +#define LOC_FUNC_3BUF_SIG(name, type_name, op) \ + void ompi_op_cuda_3buff_##name##_##type_name##_submit(const ompi_op_predefined_##type_name##_t *a1, \ + const ompi_op_predefined_##type_name##_t *a2, \ + ompi_op_predefined_##type_name##_t *b, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + CUstream stream); + + +/************************************************************************* + * Max + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(max, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(max, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(max, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(max, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(max, 
int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(max, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(max, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(max, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(max, long, long) +FUNC_FUNC_3BUF_SIG(max, ulong, unsigned long) + +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF_SIG(max, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF_SIG(max, short_float, opal_short_float_t) +#endif +FUNC_FUNC_3BUF_SIG(max, float, float) +FUNC_FUNC_3BUF_SIG(max, double, double) +FUNC_FUNC_3BUF_SIG(max, long_double, long double) + +/************************************************************************* + * Min + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(min, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(min, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(min, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(min, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(min, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(min, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(min, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(min, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(min, long, long) +FUNC_FUNC_3BUF_SIG(min, ulong, unsigned long) + +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF_SIG(min, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF_SIG(min, short_float, opal_short_float_t) +#endif +FUNC_FUNC_3BUF_SIG(min, float, float) +FUNC_FUNC_3BUF_SIG(min, double, double) +FUNC_FUNC_3BUF_SIG(min, long_double, long double) + +/************************************************************************* + * Sum + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF_SIG(sum, int8_t, int8_t, +) +OP_FUNC_3BUF_SIG(sum, uint8_t, uint8_t, +) +OP_FUNC_3BUF_SIG(sum, int16_t, int16_t, +) +OP_FUNC_3BUF_SIG(sum, uint16_t, uint16_t, +) +OP_FUNC_3BUF_SIG(sum, int32_t, int32_t, +) +OP_FUNC_3BUF_SIG(sum, uint32_t, uint32_t, +) 
+OP_FUNC_3BUF_SIG(sum, int64_t, int64_t, +) +OP_FUNC_3BUF_SIG(sum, uint64_t, uint64_t, +) +OP_FUNC_3BUF_SIG(sum, long, long, +) +OP_FUNC_3BUF_SIG(sum, ulong, unsigned long, +) + +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_3BUF_SIG(sum, short_float, short float, +) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC_3BUF_SIG(sum, short_float, opal_short_float_t, +) +#endif +OP_FUNC_3BUF_SIG(sum, float, float, +) +OP_FUNC_3BUF_SIG(sum, double, double, +) +OP_FUNC_3BUF_SIG(sum, long_double, long double, +) + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF_SIG(sum, c_short_float_complex, short float _Complex, +) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_3BUF_SIG(sum, c_long_double_complex, long double _Complex, +) +#endif // 0 +FUNC_FUNC_3BUF_SIG(sum, c_float_complex, cuFloatComplex) +FUNC_FUNC_3BUF_SIG(sum, c_double_complex, cuDoubleComplex) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF_SIG(prod, int8_t, int8_t, *) +OP_FUNC_3BUF_SIG(prod, uint8_t, uint8_t, *) +OP_FUNC_3BUF_SIG(prod, int16_t, int16_t, *) +OP_FUNC_3BUF_SIG(prod, uint16_t, uint16_t, *) +OP_FUNC_3BUF_SIG(prod, int32_t, int32_t, *) +OP_FUNC_3BUF_SIG(prod, uint32_t, uint32_t, *) +OP_FUNC_3BUF_SIG(prod, int64_t, int64_t, *) +OP_FUNC_3BUF_SIG(prod, uint64_t, uint64_t, *) +OP_FUNC_3BUF_SIG(prod, long, long, *) +OP_FUNC_3BUF_SIG(prod, ulong, unsigned long, *) + +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_3BUF_SIG(prod, short_float, short float, *) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC_3BUF_SIG(prod, short_float, opal_short_float_t, *) +#endif +OP_FUNC_3BUF_SIG(prod, float, float, *) +OP_FUNC_3BUF_SIG(prod, double, double, *) +OP_FUNC_3BUF_SIG(prod, long_double, long double, *) + +/* Complex */ +#if 0 
+#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF_SIG(prod, c_short_float_complex, short float _Complex, *) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_PROD_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_3BUF_SIG(prod, c_float_complex, float _Complex, *) +OP_FUNC_3BUF_SIG(prod, c_double_complex, double _Complex, *) +OP_FUNC_3BUF_SIG(prod, c_long_double_complex, long double _Complex, *) +#endif // 0 +FUNC_FUNC_3BUF_SIG(prod, c_float_complex, cuFloatComplex) +FUNC_FUNC_3BUF_SIG(prod, c_double_complex, cuDoubleComplex) + +/************************************************************************* + * Logical AND + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(land, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(land, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(land, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(land, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(land, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(land, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(land, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(land, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(land, long, long) +FUNC_FUNC_3BUF_SIG(land, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_3BUF_SIG(land, bool, bool) + +/************************************************************************* + * Logical OR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(lor, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(lor, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(lor, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(lor, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(lor, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(lor, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(lor, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(lor, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(lor, long, long) +FUNC_FUNC_3BUF_SIG(lor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_3BUF_SIG(lor, bool, bool) + 
+/************************************************************************* + * Logical XOR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(lxor, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(lxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(lxor, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(lxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(lxor, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(lxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(lxor, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(lxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(lxor, long, long) +FUNC_FUNC_3BUF_SIG(lxor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_3BUF_SIG(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(band, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(band, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(band, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(band, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(band, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(band, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(band, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(band, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(band, long, long) +FUNC_FUNC_3BUF_SIG(band, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_3BUF_SIG(band, byte, char) + +/************************************************************************* + * Bitwise OR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(bor, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(bor, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(bor, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(bor, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(bor, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(bor, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(bor, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(bor, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(bor, long, long) +FUNC_FUNC_3BUF_SIG(bor, ulong, unsigned long) + +/* 
Byte */ +FUNC_FUNC_3BUF_SIG(bor, byte, char) + +/************************************************************************* + * Bitwise XOR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(bxor, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(bxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(bxor, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(bxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(bxor, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(bxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(bxor, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(bxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(bxor, long, long) +FUNC_FUNC_3BUF_SIG(bxor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_3BUF_SIG(bxor, byte, char) + +/************************************************************************* + * Max location + *************************************************************************/ + +LOC_FUNC_3BUF_SIG(maxloc, float_int, >) +LOC_FUNC_3BUF_SIG(maxloc, double_int, >) +LOC_FUNC_3BUF_SIG(maxloc, long_int, >) +LOC_FUNC_3BUF_SIG(maxloc, 2int, >) +LOC_FUNC_3BUF_SIG(maxloc, short_int, >) +LOC_FUNC_3BUF_SIG(maxloc, long_double_int, >) + +/************************************************************************* + * Min location + *************************************************************************/ + +LOC_FUNC_3BUF_SIG(minloc, float_int, <) +LOC_FUNC_3BUF_SIG(minloc, double_int, <) +LOC_FUNC_3BUF_SIG(minloc, long_int, <) +LOC_FUNC_3BUF_SIG(minloc, 2int, <) +LOC_FUNC_3BUF_SIG(minloc, short_int, <) +LOC_FUNC_3BUF_SIG(minloc, long_double_int, <) + +END_C_DECLS diff --git a/ompi/mca/op/op.h b/ompi/mca/op/op.h index 34d26376ab9..097c2a109b4 100644 --- a/ompi/mca/op/op.h +++ b/ompi/mca/op/op.h @@ -85,6 +85,7 @@ #include "ompi_config.h" #include "opal/class/opal_object.h" +#include "opal/mca/accelerator/accelerator.h" #include "ompi/mca/mca.h" /* @@ -266,6 +267,22 @@ typedef void (*ompi_op_base_handler_fn_1_0_0_t)(const void *, void *, int *, typedef 
ompi_op_base_handler_fn_1_0_0_t ompi_op_base_handler_fn_t; +/** + * Typedef for 2-buffer op functions on streams/devices. + * + * We don't use MPI_User_function because this would create a + * confusing dependency loop between this file and mpi.h. So this is + * repeated code, but it's better this way (and this typedef will + * never change, so there's not much of a maintenance worry). + */ +typedef void (*ompi_op_base_stream_handler_fn_1_0_0_t)(const void *, void *, int *, + struct ompi_datatype_t **, + int device, + opal_accelerator_stream_t *stream, + struct ompi_op_base_module_1_0_0_t *); + +typedef ompi_op_base_stream_handler_fn_1_0_0_t ompi_op_base_stream_handler_fn_t; + /* * Typedef for 3-buffer (two input and one output) op functions. */ @@ -277,6 +294,19 @@ typedef void (*ompi_op_base_3buff_handler_fn_1_0_0_t)(const void *, typedef ompi_op_base_3buff_handler_fn_1_0_0_t ompi_op_base_3buff_handler_fn_t; +/* + * Typedef for 3-buffer (two input and one output) op functions on streams. 
+ */ +typedef void (*ompi_op_base_3buff_stream_handler_fn_1_0_0_t)(const void *, + const void *, + void *, int *, + struct ompi_datatype_t **, + int device, + opal_accelerator_stream_t*, + struct ompi_op_base_module_1_0_0_t *); + +typedef ompi_op_base_3buff_stream_handler_fn_1_0_0_t ompi_op_base_3buff_stream_handler_fn_t; + /** * Op component initialization * @@ -376,10 +406,18 @@ typedef struct ompi_op_base_module_1_0_0_t { is being used for */ struct ompi_op_t *opm_op; + bool opm_device_enabled; + /** Function pointers for all the different datatypes to be used with the MPI_Op that this module is used with */ - ompi_op_base_handler_fn_1_0_0_t opm_fns[OMPI_OP_BASE_TYPE_MAX]; - ompi_op_base_3buff_handler_fn_1_0_0_t opm_3buff_fns[OMPI_OP_BASE_TYPE_MAX]; + union { + ompi_op_base_handler_fn_1_0_0_t opm_fns[OMPI_OP_BASE_TYPE_MAX]; + ompi_op_base_stream_handler_fn_1_0_0_t opm_stream_fns[OMPI_OP_BASE_TYPE_MAX]; + }; + union { + ompi_op_base_3buff_handler_fn_1_0_0_t opm_3buff_fns[OMPI_OP_BASE_TYPE_MAX]; + ompi_op_base_3buff_stream_handler_fn_1_0_0_t opm_3buff_stream_fns[OMPI_OP_BASE_TYPE_MAX]; + }; } ompi_op_base_module_1_0_0_t; /** @@ -404,6 +442,18 @@ typedef struct ompi_op_base_op_fns_1_0_0_t { typedef ompi_op_base_op_fns_1_0_0_t ompi_op_base_op_fns_t; +/** + * Struct that is used in op.h to hold all the function pointers and + * pointers to the corresopnding modules (so that we can properly + * RETAIN/RELEASE them) + */ +typedef struct ompi_op_base_op_stream_fns_1_0_0_t { + ompi_op_base_stream_handler_fn_1_0_0_t fns[OMPI_OP_BASE_TYPE_MAX]; + ompi_op_base_module_t *modules[OMPI_OP_BASE_TYPE_MAX]; +} ompi_op_base_op_stream_fns_1_0_0_t; + +typedef ompi_op_base_op_stream_fns_1_0_0_t ompi_op_base_op_stream_fns_t; + /** * Struct that is used in op.h to hold all the function pointers and * pointers to the corresopnding modules (so that we can properly @@ -416,6 +466,18 @@ typedef struct ompi_op_base_op_3buff_fns_1_0_0_t { typedef ompi_op_base_op_3buff_fns_1_0_0_t 
ompi_op_base_op_3buff_fns_t; +/** + * Struct that is used in op.h to hold all the function pointers and + * pointers to the corresopnding modules (so that we can properly + * RETAIN/RELEASE them) + */ +typedef struct ompi_op_base_op_3buff_stream_fns_1_0_0_t { + ompi_op_base_3buff_stream_handler_fn_1_0_0_t fns[OMPI_OP_BASE_TYPE_MAX]; + ompi_op_base_module_t *modules[OMPI_OP_BASE_TYPE_MAX]; +} ompi_op_base_op_3buff_stream_fns_1_0_0_t; + +typedef ompi_op_base_op_3buff_stream_fns_1_0_0_t ompi_op_base_op_3buff_stream_fns_t; + /* * Macro for use in modules that are of type op v2.0.0 */ diff --git a/ompi/mca/op/rocm/Makefile.am b/ompi/mca/op/rocm/Makefile.am new file mode 100644 index 00000000000..1b79e890f72 --- /dev/null +++ b/ompi/mca/op/rocm/Makefile.am @@ -0,0 +1,82 @@ +# +# Copyright (c) 2023 The University of Tennessee and The University +# of Tennessee Research Foundation. All rights +# reserved. +# $COPYRIGHT$ +# +# Additional copyrights may follow +# +# $HEADER$ +# + +# This component provides support for offloading reduce ops to ROCM devices. +# +# See https://github.com/open-mpi/ompi/wiki/devel-CreateComponent +# for more details on how to make Open MPI components. + +# First, list all .h and .c sources. It is necessary to list all .h +# files so that they will be picked up in the distribution tarball. + +AM_CPPFLAGS = $(op_rocm_CPPFLAGS) + +dist_ompidata_DATA = help-ompi-mca-op-rocm.txt + +sources = op_rocm_component.c op_rocm.h op_rocm_functions.c op_rocm_impl.h +rocm_sources = op_rocm_impl.cpp + +HIPCC = hipcc + +.cpp.l$(OBJEXT): + $(LIBTOOL) $(AM_V_lt) --tag=CC $(AM_LIBTOOLFLAGS) \ + $(LIBTOOLFLAGS) --mode=compile $(HIPCC) -O2 -std=c++17 -fvectorize -prefer-non-pic -Wc,-fPIC,-g -c $< + +# -o $($@.o:.lo) + +# Open MPI components can be compiled two ways: +# +# 1. As a standalone dynamic shared object (DSO), sometimes called a +# dynamically loadable library (DLL). +# +# 2. 
As a static library that is slurped up into the upper-level +# libmpi library (regardless of whether libmpi is a static or dynamic +# library). This is called a "Libtool convenience library". +# +# The component needs to create an output library in this top-level +# component directory, and named either mca__.la (for DSO +# builds) or libmca__.la (for static builds). The OMPI +# build system will have set the +# MCA_BUILD_ompi___DSO AM_CONDITIONAL to indicate +# which way this component should be built. + +if MCA_BUILD_ompi_op_rocm_DSO +component_install = mca_op_rocm.la +else +component_install = +component_noinst = libmca_op_rocm.la +endif + +# Specific information for DSO builds. +# +# The DSO should install itself in $(ompilibdir) (by default, +# $prefix/lib/openmpi). + +mcacomponentdir = $(ompilibdir) +mcacomponent_LTLIBRARIES = $(component_install) +mca_op_rocm_la_SOURCES = $(sources) +mca_op_rocm_la_LIBADD = $(rocm_sources:.cpp=.lo) +mca_op_rocm_la_LDFLAGS = -module -avoid-version $(top_builddir)/ompi/lib@OMPI_LIBMPI_NAME@.la \ + $(op_rocm_LIBS) +EXTRA_mca_op_rocm_la_SOURCES = $(rocm_sources) + +# Specific information for static builds. +# +# Note that we *must* "noinst"; the upper-layer Makefile.am's will +# slurp in the resulting .la library into libmpi. + +noinst_LTLIBRARIES = $(component_noinst) +libmca_op_rocm_la_SOURCES = $(sources) +libmca_op_rocm_la_LIBADD = $(rocm_sources:.cpp=.lo) +libmca_op_rocm_la_LDFLAGS = -module -avoid-version\ + $(op_rocm_LIBS) +EXTRA_libmca_op_rocm_la_SOURCES = $(rocm_sources) + diff --git a/ompi/mca/op/rocm/configure.m4 b/ompi/mca/op/rocm/configure.m4 new file mode 100644 index 00000000000..ffd88698be0 --- /dev/null +++ b/ompi/mca/op/rocm/configure.m4 @@ -0,0 +1,36 @@ +# -*- shell-script -*- +# +# Copyright (c) 2011-2013 NVIDIA Corporation. All rights reserved. +# Copyright (c) 2023 The University of Tennessee and The University +# of Tennessee Research Foundation. All rights +# reserved. 
+# Copyright (c) 2022 Amazon.com, Inc. or its affiliates. +# All Rights reserved. +# $COPYRIGHT$ +# +# Additional copyrights may follow +# +# $HEADER$ +# + +# +# If ROCm support was requested, then build the ROCm support library. +# This code checks makes sure the check was done earlier by the +# opal_check_rocm.m4 code. It also copies the flags and libs under +# opal_rocm_CPPFLAGS, opal_rocm_LDFLAGS, and opal_rocm_LIBS + +AC_DEFUN([MCA_ompi_op_rocm_CONFIG],[ + + AC_CONFIG_FILES([ompi/mca/op/rocm/Makefile]) + + OPAL_CHECK_ROCM([op_rocm]) + + AS_IF([test "x$ROCM_SUPPORT" = "x1"], + [$1], + [$2]) + + AC_SUBST([op_rocm_CPPFLAGS]) + AC_SUBST([op_rocm_LDFLAGS]) + AC_SUBST([op_rocm_LIBS]) + +])dnl diff --git a/ompi/mca/op/rocm/help-ompi-mca-op-rocm.txt b/ompi/mca/op/rocm/help-ompi-mca-op-rocm.txt new file mode 100644 index 00000000000..848afbb663d --- /dev/null +++ b/ompi/mca/op/rocm/help-ompi-mca-op-rocm.txt @@ -0,0 +1,15 @@ +# -*- text -*- +# +# Copyright (c) 2023 The University of Tennessee and The University +# of Tennessee Research Foundation. All rights +# reserved. +# $COPYRIGHT$ +# +# Additional copyrights may follow +# +# $HEADER$ +# +# This is the US/English help file for Open MPI's HIP operator component +# +[HIP call failed] +"HIP call %s failed: %s: %s\n" diff --git a/ompi/mca/op/rocm/op_rocm.h b/ompi/mca/op/rocm/op_rocm.h new file mode 100644 index 00000000000..0dfeabf689b --- /dev/null +++ b/ompi/mca/op/rocm/op_rocm.h @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2019-2023 The University of Tennessee and The University + * of Tennessee Research Foundation. All rights + * reserved. 
+ * $COPYRIGHT$
+ *
+ * Additional copyrights may follow
+ *
+ * $HEADER$
+ */
+
+/* Include guard renamed from MCA_OP_CUDA_EXPORT_H: this is the ROCm
+ * component's header; reusing the CUDA guard would silently suppress
+ * this header if both components' headers were ever included together. */
+#ifndef MCA_OP_ROCM_EXPORT_H
+#define MCA_OP_ROCM_EXPORT_H
+
+#include "ompi_config.h"
+
+#include "ompi/mca/mca.h"
+#include "opal/class/opal_object.h"
+
+#include "ompi/mca/op/op.h"
+#include "ompi/runtime/mpiruntime.h"
+
+#include
+#include
+
+BEGIN_C_DECLS
+
+
+#define xstr(x) #x
+#define str(x) xstr(x)
+
+#define CHECK(fn, args)                                   \
+    do {                                                  \
+        hipError_t err = fn args;                         \
+        if (err != hipSuccess) {                          \
+            opal_show_help("help-ompi-mca-op-rocm.txt",   \
+                           "HIP call failed", true,       \
+                           str(fn), hipGetErrorName(err), \
+                           hipGetErrorString(err));       \
+            ompi_mpi_abort(MPI_COMM_WORLD, 1);            \
+        }                                                 \
+    } while (0)
+
+
+/**
+ * Derive a struct from the base op component struct, allowing us to
+ * cache some component-specific information on our well-known
+ * component struct.
+ */
+typedef struct {
+    /** The base op component struct */
+    ompi_op_base_component_1_0_0_t super;
+
+#if 0
+    /* a stream on which to schedule kernel calls */
+    hipStream_t ro_stream;
+    hipCtx_t *ro_ctx;
+#endif // 0
+    int ro_max_num_blocks;
+    int ro_max_num_threads;
+    int *ro_max_threads_per_block;
+    int *ro_max_blocks;
+    hipDevice_t *ro_devices;
+    int ro_num_devices;
+} ompi_op_rocm_component_t;
+
+/**
+ * Globally exported variable. Note that it is a *rocm* component
+ * (defined above), which has the ompi_op_base_component_t as its
+ * first member. Hence, the MCA/op framework will find the data that
+ * it expects in the first memory locations, but then the component
+ * itself can cache additional information after that that can be used
+ * by both the component and modules.
+ */
+OMPI_DECLSPEC extern ompi_op_rocm_component_t
+    mca_op_rocm_component;
+
+OMPI_DECLSPEC extern
+ompi_op_base_stream_handler_fn_t ompi_op_rocm_functions[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX];
+
+OMPI_DECLSPEC extern
+ompi_op_base_3buff_stream_handler_fn_t ompi_op_rocm_3buff_functions[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX];
+
+END_C_DECLS
+
+#endif /* MCA_OP_ROCM_EXPORT_H */
diff --git a/ompi/mca/op/rocm/op_rocm_component.c b/ompi/mca/op/rocm/op_rocm_component.c
new file mode 100644
index 00000000000..e363bf94385
--- /dev/null
+++ b/ompi/mca/op/rocm/op_rocm_component.c
@@ -0,0 +1,219 @@
+/*
+ * Copyright (c) 2019-2023 The University of Tennessee and The University
+ *                         of Tennessee Research Foundation. All rights
+ *                         reserved.
+ * Copyright (c) 2020      Research Organization for Information Science
+ *                         and Technology (RIST). All rights reserved.
+ * Copyright (c) 2021      Cisco Systems, Inc. All rights reserved.
+ * $COPYRIGHT$
+ *
+ * Additional copyrights may follow
+ *
+ * $HEADER$
+ */
+
+/** @file
+ *
+ * This is the "rocm" op component source code.
+ * + */ + +#include "ompi_config.h" + +#include "opal/util/printf.h" + +#include "ompi/constants.h" +#include "ompi/op/op.h" +#include "ompi/mca/op/op.h" +#include "ompi/mca/op/base/base.h" +#include "ompi/mca/op/rocm/op_rocm.h" + +#include + +static int rocm_component_open(void); +static int rocm_component_close(void); +static int rocm_component_init_query(bool enable_progress_threads, + bool enable_mpi_thread_multiple); +static struct ompi_op_base_module_1_0_0_t * + rocm_component_op_query(struct ompi_op_t *op, int *priority); +static int rocm_component_register(void); + +ompi_op_rocm_component_t mca_op_rocm_component = { + { + .opc_version = { + OMPI_OP_BASE_VERSION_1_0_0, + + .mca_component_name = "rocm", + MCA_BASE_MAKE_VERSION(component, OMPI_MAJOR_VERSION, OMPI_MINOR_VERSION, + OMPI_RELEASE_VERSION), + .mca_open_component = rocm_component_open, + .mca_close_component = rocm_component_close, + .mca_register_component_params = rocm_component_register, + }, + .opc_data = { + /* The component is checkpoint ready */ + MCA_BASE_METADATA_PARAM_CHECKPOINT + }, + + .opc_init_query = rocm_component_init_query, + .opc_op_query = rocm_component_op_query, + }, + .ro_max_num_blocks = -1, + .ro_max_num_threads = -1, + .ro_max_threads_per_block = NULL, + .ro_max_blocks = NULL, + .ro_devices = NULL, + .ro_num_devices = 0, +}; + +/* + * Component open + */ +static int rocm_component_open(void) +{ + /* We checked the flags during register, so if they are set to + * zero either the architecture is not suitable or the user disabled + * AVX support. + * + * A first level check to see what level of AVX is available on the + * hardware. + * + * Note that if this function returns non-OMPI_SUCCESS, then this + * component won't even be shown in ompi_info output (which is + * probably not what you want). 
+ */
+    return OMPI_SUCCESS;
+}
+
+/*
+ * Component close
+ */
+static int rocm_component_close(void)
+{
+    if (mca_op_rocm_component.ro_num_devices > 0) {
+        //hipStreamDestroy(mca_op_rocm_component.ro_stream);
+        free(mca_op_rocm_component.ro_max_threads_per_block);
+        mca_op_rocm_component.ro_max_threads_per_block = NULL;
+        free(mca_op_rocm_component.ro_max_blocks);
+        mca_op_rocm_component.ro_max_blocks = NULL;
+        free(mca_op_rocm_component.ro_devices);
+        mca_op_rocm_component.ro_devices = NULL;
+        mca_op_rocm_component.ro_num_devices = 0;
+    }
+
+    return OMPI_SUCCESS;
+}
+
+/*
+ * Register MCA params.
+ */
+static int
+rocm_component_register(void)
+{
+    /* TODO: add mca parameters */
+
+    /* plain integer parameters take no enumerator: pass NULL below */
+    (void) mca_base_component_var_register(&mca_op_rocm_component.super.opc_version,
+                                           "max_num_blocks",
+                                           "Maximum number of thread blocks in kernels (-1: device limit)",
+                                           MCA_BASE_VAR_TYPE_INT,
+                                           NULL, 0, 0,
+                                           OPAL_INFO_LVL_4,
+                                           MCA_BASE_VAR_SCOPE_LOCAL,
+                                           &mca_op_rocm_component.ro_max_num_blocks);
+
+    (void) mca_base_component_var_register(&mca_op_rocm_component.super.opc_version,
+                                           "max_num_threads",
+                                           "Maximum number of threads per block in kernels (-1: device limit)",
+                                           MCA_BASE_VAR_TYPE_INT,
+                                           NULL, 0, 0,
+                                           OPAL_INFO_LVL_4,
+                                           MCA_BASE_VAR_SCOPE_LOCAL,
+                                           &mca_op_rocm_component.ro_max_num_threads);
+
+    return OMPI_SUCCESS;
+}
+
+
+/*
+ * Query whether this component wants to be used in this process.
+ */
+static int
+rocm_component_init_query(bool enable_progress_threads,
+                          bool enable_mpi_thread_multiple)
+{
+    int num_devices;
+    int rc;
+    int prio_lo, prio_hi;
+    //memset(&mca_op_rocm_component, 0, sizeof(mca_op_rocm_component));
+    CHECK(hipInit, (0));
+    CHECK(hipGetDeviceCount, (&num_devices));
+    mca_op_rocm_component.ro_num_devices = num_devices;
+    mca_op_rocm_component.ro_devices = (hipDevice_t*)malloc(num_devices*sizeof(hipDevice_t));
+    mca_op_rocm_component.ro_max_threads_per_block = (int*)malloc(num_devices*sizeof(int));
+    mca_op_rocm_component.ro_max_blocks = (int*)malloc(num_devices*sizeof(int));
+    for (int i = 0; i < num_devices; ++i) {
+        CHECK(hipDeviceGet, (&mca_op_rocm_component.ro_devices[i], i));
+        rc = hipDeviceGetAttribute(&mca_op_rocm_component.ro_max_threads_per_block[i],
+                                   hipDeviceAttributeMaxBlockDimX,
+                                   mca_op_rocm_component.ro_devices[i]);
+        if (hipSuccess != rc) {
+            /* fall-back to value that should work on every device */
+            mca_op_rocm_component.ro_max_threads_per_block[i] = 512;
+        }
+        if (-1 < mca_op_rocm_component.ro_max_num_threads) {
+            if (mca_op_rocm_component.ro_max_threads_per_block[i] > mca_op_rocm_component.ro_max_num_threads) {
+                mca_op_rocm_component.ro_max_threads_per_block[i] = mca_op_rocm_component.ro_max_num_threads;
+            }
+        }
+
+        rc = hipDeviceGetAttribute(&mca_op_rocm_component.ro_max_blocks[i],
+                                   hipDeviceAttributeMaxGridDimX,
+                                   mca_op_rocm_component.ro_devices[i]);
+        if (hipSuccess != rc) {
+            /* we'll try to max out the blocks */
+            mca_op_rocm_component.ro_max_blocks[i] = 512;
+        }
+        if (-1 < mca_op_rocm_component.ro_max_num_blocks) {
+            if (mca_op_rocm_component.ro_max_blocks[i] > mca_op_rocm_component.ro_max_num_blocks) {
+                mca_op_rocm_component.ro_max_blocks[i] = mca_op_rocm_component.ro_max_num_blocks;
+            }
+        }
+    }
+
+#if 0
+    /* try to create a high-priority stream (only if the range query succeeded) */
+    rc = hipDeviceGetStreamPriorityRange(&prio_lo, &prio_hi);
+    if (hipSuccess == rc) {
+        hipStreamCreateWithPriority(&mca_op_rocm_component.ro_stream, hipStreamNonBlocking, prio_hi);
+    } else {
+        mca_op_rocm_component.ro_stream = 0;
+    }
+#endif // 0
+    return OMPI_SUCCESS;
+}
+
+/*
+ * Query whether this component can be used for a specific op
+ */
+static struct ompi_op_base_module_1_0_0_t*
+rocm_component_op_query(struct ompi_op_t *op, int *priority)
+{
+    ompi_op_base_module_t *module = NULL;
+
+    module = OBJ_NEW(ompi_op_base_module_t);
+    module->opm_device_enabled = true;
+    for (int i = 0; i < OMPI_OP_BASE_TYPE_MAX; ++i) {
+        module->opm_stream_fns[i] = ompi_op_rocm_functions[op->o_f_to_c_index][i];
+        module->opm_3buff_stream_fns[i] = ompi_op_rocm_3buff_functions[op->o_f_to_c_index][i];
+
+        /* retain once per function this module provides: we assign the
+         * *stream* tables above, so those are the ones to test */
+        if( NULL != module->opm_stream_fns[i] ) {
+            OBJ_RETAIN(module);
+        }
+        if( NULL != module->opm_3buff_stream_fns[i] ) {
+            OBJ_RETAIN(module);
+        }
+    }
+    *priority = 50;
+    return (ompi_op_base_module_1_0_0_t *) module;
+}
diff --git a/ompi/mca/op/rocm/op_rocm_functions.c b/ompi/mca/op/rocm/op_rocm_functions.c
new file mode 100644
index 00000000000..43420dc18a7
--- /dev/null
+++ b/ompi/mca/op/rocm/op_rocm_functions.c
@@ -0,0 +1,1722 @@
+/*
+ * Copyright (c) 2019-2023 The University of Tennessee and The University
+ *                         of Tennessee Research Foundation.  All rights
+ *                         reserved.
+ * Copyright (c) 2020      Research Organization for Information Science
+ *                         and Technology (RIST).  All rights reserved.
+ * $COPYRIGHT$
+ *
+ * Additional copyrights may follow
+ *
+ * $HEADER$
+ */
+
+#include "ompi_config.h"
+
+#ifdef HAVE_SYS_TYPES_H
+#include <sys/types.h>
+#endif
+#include "opal/util/output.h"
+
+
+#include "ompi/op/op.h"
+#include "ompi/mca/op/op.h"
+#include "ompi/mca/op/base/base.h"
+#include "ompi/mca/op/rocm/op_rocm.h"
+#include "opal/mca/accelerator/accelerator.h"
+
+#include <hip/hip_runtime.h>
+#include "ompi/mca/op/rocm/op_rocm_impl.h"
+
+
+static inline void device_op_pre(const void *orig_source1,
+                                 void **source1,
+                                 int *source1_device,
+                                 const void *orig_source2,
+                                 void **source2,
+                                 int *source2_device,
+                                 void *orig_target,
+                                 void **target,
+                                 int *target_device,
+                                 int count,
+                                 struct ompi_datatype_t *dtype,
+                                 int *threads_per_block,
+                                 int *max_blocks,
+                                 int *device,
+                                 opal_accelerator_stream_t *stream)
+{
+    uint64_t target_flags = -1, source1_flags = -1, source2_flags = -1;
+    int target_rc = -1, source1_rc = -1, source2_rc = -1;
+
+    *target = orig_target;
+    *source1 = (void*)orig_source1;
+    if (NULL != orig_source2) {
+        *source2 = (void*)orig_source2;
+    }
+
+    if (*device != MCA_ACCELERATOR_NO_DEVICE_ID) {
+        /* we got the device from the caller, just adjust the output parameters */
+        *target_device = *device;
+        *source1_device = *device;
+        if (NULL != source2_device) {
+            *source2_device = *device;
+        }
+    } else {
+
+        target_rc = opal_accelerator.check_addr(*target, target_device, &target_flags);
+        source1_rc = opal_accelerator.check_addr(*source1, source1_device, &source1_flags);
+        *device = *target_device;
+
+        if (NULL != orig_source2) {
+            source2_rc = opal_accelerator.check_addr(*source2, source2_device, &source2_flags);
+        }
+
+        if (0 == target_rc && 0 == source1_rc && 0 == source2_rc) {
+            /* no buffers are on any device, select device 0 */
+            *device = 0;
+        } else if (*target_device == -1) {
+            if (*source1_device == -1 && NULL != orig_source2) {
+                *device = *source2_device;
+            } else {
+                *device = *source1_device;
+            }
+        }
+
+        if (0 == target_rc || 0 == source1_rc
|| *target_device != *source1_device) { + size_t nbytes; + ompi_datatype_type_size(dtype, &nbytes); + nbytes *= count; + + if (0 == target_rc) { + // allocate memory on the device for the target buffer + opal_accelerator.mem_alloc_stream(*device, target, nbytes, stream); + CHECK(hipMemcpyHtoDAsync, ((hipDeviceptr_t)*target, orig_target, nbytes, *(hipStream_t*)stream->stream)); + *target_device = -1; // mark target device as host + } + + if (0 == source1_rc || *device != *source1_device) { + // allocate memory on the device for the source buffer + opal_accelerator.mem_alloc_stream(*device, source1, nbytes, stream); + if (0 == source1_rc) { + /* copy from host to device */ + CHECK(hipMemcpyHtoDAsync, ((hipDeviceptr_t)*source1, (void*)orig_source1, nbytes, *(hipStream_t*)stream->stream)); + } else { + /* copy from one device to another device */ + /* TODO: does this actually work? Can we enable P2P? */ + CHECK(hipMemcpyDtoDAsync, ((hipDeviceptr_t)*source1, (hipDeviceptr_t)orig_source1, nbytes, *(hipStream_t*)stream->stream)); + } + } + + } + if (NULL != source2_device && *target_device != *source2_device) { + // allocate memory on the device for the source buffer + //printf("allocating source on device %d\n", *device); + size_t nbytes; + ompi_datatype_type_size(dtype, &nbytes); + nbytes *= count; + + opal_accelerator.mem_alloc_stream(*device, source2, nbytes, stream); + if (0 == source2_rc) { + /* copy from host to device */ + //printf("copying source from host to device %d\n", *device); + CHECK(hipMemcpyHtoDAsync, ((hipDeviceptr_t)*source2, (void*)orig_source2, nbytes, *(hipStream_t*)stream->stream)); + } else { + /* copy from one device to another device */ + /* TODO: does this actually work? Can we enable P2P? 
*/ + //printf("attempting cross-device copy for source\n"); + CHECK(hipMemcpyDtoDAsync, ((hipDeviceptr_t)*source2, (hipDeviceptr_t)orig_source2, nbytes, *(hipStream_t*)stream->stream)); + } + } + } + + *threads_per_block = mca_op_rocm_component.ro_max_threads_per_block[*device]; + *max_blocks = mca_op_rocm_component.ro_max_blocks[*device]; + +} + +static inline void device_op_post(void *source1, + int source1_device, + void *source2, + int source2_device, + void *orig_target, + void *target, + int target_device, + int count, + struct ompi_datatype_t *dtype, + int device, + opal_accelerator_stream_t *stream) +{ + if (MCA_ACCELERATOR_NO_DEVICE_ID == target_device) { + + size_t nbytes; + ompi_datatype_type_size(dtype, &nbytes); + nbytes *= count; + + CHECK(hipMemcpyDtoHAsync, (orig_target, (hipDeviceptr_t)target, nbytes, *(hipStream_t *)stream->stream)); + } + + if (MCA_ACCELERATOR_NO_DEVICE_ID == target_device) { + opal_accelerator.mem_release_stream(device, target, stream); + } + if (source1_device != device) { + opal_accelerator.mem_release_stream(device, source1, stream); + } + if (NULL != source2 && source2_device != device) { + opal_accelerator.mem_release_stream(device, source2, stream); + } +} + +#define FUNC(name, type_name, type) \ + static \ + void ompi_op_rocm_2buff_##name##_##type_name(const void *in, void *inout, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + int threads_per_block, max_blocks; \ + int source_device, target_device; \ + type *source, *target; \ + int n = *count; \ + device_op_pre(in, (void**)&source, &source_device, NULL, NULL, NULL, \ + inout, (void**)&target, &target_device, \ + n, *dtype, \ + &threads_per_block, &max_blocks, &device, stream); \ + hipStream_t *custream = (hipStream_t*)stream->stream; \ + ompi_op_rocm_2buff_##name##_##type_name##_submit(source, target, n, threads_per_block, max_blocks, *custream);\ + 
device_op_post(source, source_device, NULL, -1, inout, target, target_device, n, *dtype, device, stream); \ + } + +#define OP_FUNC(name, type_name, type, op, ...) FUNC(name, __VA_ARGS__##type_name, __VA_ARGS__##type) + +/* reuse the macro above, no work is actually done so we don't care about the func */ +#define FUNC_FUNC(name, type_name, type, ...) FUNC(name, __VA_ARGS__##type_name, __VA_ARGS__##type) + +/* + * Since all the functions in this file are essentially identical, we + * use a macro to substitute in names and types. The core operation + * in all functions that use this macro is the same. + * + * This macro is for minloc and maxloc + */ +#define LOC_FUNC(name, type_name, op) FUNC(name, type_name, ompi_op_predefined_##type_name##_t) + +/* Dispatch Fortran types to C types */ +#define FORT_INT_FUNC(name, type_name, type) \ + static \ + void ompi_op_rocm_2buff_##name##_##type_name(const void *in, void *inout, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + \ + _Static_assert(sizeof(type) >= sizeof(int8_t) && sizeof(type) <= sizeof(int64_t)); \ + switch(sizeof(type)) { \ + case sizeof(int8_t): \ + ompi_op_rocm_2buff_##name##_int8_t(in, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(int16_t): \ + ompi_op_rocm_2buff_##name##_int16_t(in, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(int32_t): \ + ompi_op_rocm_2buff_##name##_int32_t(in, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(int64_t): \ + ompi_op_rocm_2buff_##name##_int64_t(in, inout, count, dtype, device, stream, module); \ + break; \ + } \ + } + +/* Dispatch Fortran types to C types */ +#define FORT_FLOAT_FUNC(name, type_name, type) \ + static \ + void ompi_op_rocm_2buff_##name##_##type_name(const void *in, void *inout, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t 
*stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + _Static_assert(sizeof(type) >= sizeof(float) && sizeof(type) <= sizeof(long double)); \ + switch(sizeof(type)) { \ + case sizeof(float): \ + ompi_op_rocm_2buff_##name##_float(in, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(double): \ + ompi_op_rocm_2buff_##name##_double(in, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(long double): \ + ompi_op_rocm_2buff_##name##_long_double(in, inout, count, dtype, device, stream, module); \ + break; \ + } \ + } + +/************************************************************************* + * Max + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) > (b) ? (a) : (b)) +/* C integer */ +FUNC_FUNC(max, int8_t, int8_t) +FUNC_FUNC(max, uint8_t, uint8_t) +FUNC_FUNC(max, int16_t, int16_t) +FUNC_FUNC(max, uint16_t, uint16_t) +FUNC_FUNC(max, int32_t, int32_t) +FUNC_FUNC(max, uint32_t, uint32_t) +FUNC_FUNC(max, int64_t, int64_t) +FUNC_FUNC(max, uint64_t, uint64_t) +FUNC_FUNC(max, long, long) +FUNC_FUNC(max, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(max, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(max, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(max, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(max, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(max, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(max, fortran_integer16, ompi_fortran_integer16_t) +#endif + +FUNC_FUNC(max, float, float) +FUNC_FUNC(max, double, double) +FUNC_FUNC(max, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC(max, fortran_real, 
ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC(max, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC(max, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC(max, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC(max, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC(max, fortran_real16, ompi_fortran_real16_t) +#endif + + +/************************************************************************* + * Min + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) < (b) ? (a) : (b)) +/* C integer */ +FUNC_FUNC(min, int8_t, int8_t) +FUNC_FUNC(min, uint8_t, uint8_t) +FUNC_FUNC(min, int16_t, int16_t) +FUNC_FUNC(min, uint16_t, uint16_t) +FUNC_FUNC(min, int32_t, int32_t) +FUNC_FUNC(min, uint32_t, uint32_t) +FUNC_FUNC(min, int64_t, int64_t) +FUNC_FUNC(min, uint64_t, uint64_t) +FUNC_FUNC(min, long, long) +FUNC_FUNC(min, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(min, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(min, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(min, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(min, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(min, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(min, fortran_integer16, ompi_fortran_integer16_t) +#endif + +FUNC_FUNC(min, float, float) +FUNC_FUNC(min, double, double) +FUNC_FUNC(min, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC(min, fortran_real, 
ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC(min, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC(min, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC(min, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC(min, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC(min, fortran_real16, ompi_fortran_real16_t) +#endif + +/************************************************************************* + * Sum + *************************************************************************/ + +/* C integer */ +OP_FUNC(sum, int8_t, int8_t, +=) +OP_FUNC(sum, uint8_t, uint8_t, +=) +OP_FUNC(sum, int16_t, int16_t, +=) +OP_FUNC(sum, uint16_t, uint16_t, +=) +OP_FUNC(sum, int32_t, int32_t, +=) +OP_FUNC(sum, uint32_t, uint32_t, +=) +OP_FUNC(sum, int64_t, int64_t, +=) +OP_FUNC(sum, uint64_t, uint64_t, +=) +OP_FUNC(sum, long, long, +=) +OP_FUNC(sum, ulong, unsigned long, +=) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(sum, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(sum, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(sum, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(sum, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(sum, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(sum, fortran_integer16, ompi_fortran_integer16_t) +#endif + +OP_FUNC(sum, float, float, +=) +OP_FUNC(sum, double, double, +=) +OP_FUNC(sum, long_double, long double, +=) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC(sum, fortran_real, ompi_fortran_real_t) +#endif +#if 
OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC(sum, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC(sum, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC(sum, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC(sum, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC(sum, fortran_real16, ompi_fortran_real16_t) +#endif +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC(sum, c_short_float_complex, short float _Complex, +=) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC(sum, c_long_double_complex, long double _Complex, +=) +#endif // 0 +#undef current_func +#define current_func(a, b) (hipCaddf(a,b)) +FUNC_FUNC(sum, c_float_complex, hipFloatComplex) +#undef current_func +#define current_func(a, b) (hipCadd(a,b)) +FUNC_FUNC(sum, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +OP_FUNC(prod, int8_t, int8_t, *=) +OP_FUNC(prod, uint8_t, uint8_t, *=) +OP_FUNC(prod, int16_t, int16_t, *=) +OP_FUNC(prod, uint16_t, uint16_t, *=) +OP_FUNC(prod, int32_t, int32_t, *=) +OP_FUNC(prod, uint32_t, uint32_t, *=) +OP_FUNC(prod, int64_t, int64_t, *=) +OP_FUNC(prod, uint64_t, uint64_t, *=) +OP_FUNC(prod, long, long, *=) +OP_FUNC(prod, ulong, unsigned long, *=) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(prod, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(prod, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(prod, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if 
OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(prod, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(prod, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(prod, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Floating point */ + +OP_FUNC(prod, float, float, *=) +OP_FUNC(prod, double, double, *=) +OP_FUNC(prod, long_double, long double, *=) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC(prod, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC(prod, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC(prod, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC(prod, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC(prod, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC(prod, fortran_real16, ompi_fortran_real16_t) +#endif + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC(prod, c_short_float_complex, short float _Complex, *=) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_PROD_FUNC(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC(prod, c_long_double_complex, long double _Complex, *=) +#endif // 0 +#undef current_func +#define current_func(a, b) (hipCmulf(a,b)) +FUNC_FUNC(prod, c_float_complex, hipFloatComplex) +#undef current_func +#define current_func(a, b) (hipCmul(a,b)) +FUNC_FUNC(prod, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Logical AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) && (b)) +/* C integer */ +FUNC_FUNC(land, int8_t, int8_t) +FUNC_FUNC(land, uint8_t, uint8_t) +FUNC_FUNC(land, int16_t, int16_t) 
+FUNC_FUNC(land, uint16_t, uint16_t) +FUNC_FUNC(land, int32_t, int32_t) +FUNC_FUNC(land, uint32_t, uint32_t) +FUNC_FUNC(land, int64_t, int64_t) +FUNC_FUNC(land, uint64_t, uint64_t) +FUNC_FUNC(land, long, long) +FUNC_FUNC(land, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC(land, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC(land, bool, bool) + +/************************************************************************* + * Logical OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) || (b)) +/* C integer */ +FUNC_FUNC(lor, int8_t, int8_t) +FUNC_FUNC(lor, uint8_t, uint8_t) +FUNC_FUNC(lor, int16_t, int16_t) +FUNC_FUNC(lor, uint16_t, uint16_t) +FUNC_FUNC(lor, int32_t, int32_t) +FUNC_FUNC(lor, uint32_t, uint32_t) +FUNC_FUNC(lor, int64_t, int64_t) +FUNC_FUNC(lor, uint64_t, uint64_t) +FUNC_FUNC(lor, long, long) +FUNC_FUNC(lor, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC(lor, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC(lor, bool, bool) + +/************************************************************************* + * Logical XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a ? 1 : 0) ^ (b ? 
1: 0)) +/* C integer */ +FUNC_FUNC(lxor, int8_t, int8_t) +FUNC_FUNC(lxor, uint8_t, uint8_t) +FUNC_FUNC(lxor, int16_t, int16_t) +FUNC_FUNC(lxor, uint16_t, uint16_t) +FUNC_FUNC(lxor, int32_t, int32_t) +FUNC_FUNC(lxor, uint32_t, uint32_t) +FUNC_FUNC(lxor, int64_t, int64_t) +FUNC_FUNC(lxor, uint64_t, uint64_t) +FUNC_FUNC(lxor, long, long) +FUNC_FUNC(lxor, ulong, unsigned long) + + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC(lxor, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) & (b)) +/* C integer */ +FUNC_FUNC(band, int8_t, int8_t) +FUNC_FUNC(band, uint8_t, uint8_t) +FUNC_FUNC(band, int16_t, int16_t) +FUNC_FUNC(band, uint16_t, uint16_t) +FUNC_FUNC(band, int32_t, int32_t) +FUNC_FUNC(band, uint32_t, uint32_t) +FUNC_FUNC(band, int64_t, int64_t) +FUNC_FUNC(band, uint64_t, uint64_t) +FUNC_FUNC(band, long, long) +FUNC_FUNC(band, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(band, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(band, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(band, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(band, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(band, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(band, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FUNC_FUNC(band, byte, char) + +/************************************************************************* + * Bitwise OR + 
*************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) | (b)) +/* C integer */ +FUNC_FUNC(bor, int8_t, int8_t) +FUNC_FUNC(bor, uint8_t, uint8_t) +FUNC_FUNC(bor, int16_t, int16_t) +FUNC_FUNC(bor, uint16_t, uint16_t) +FUNC_FUNC(bor, int32_t, int32_t) +FUNC_FUNC(bor, uint32_t, uint32_t) +FUNC_FUNC(bor, int64_t, int64_t) +FUNC_FUNC(bor, uint64_t, uint64_t) +FUNC_FUNC(bor, long, long) +FUNC_FUNC(bor, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(bor, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(bor, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(bor, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(bor, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(bor, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(bor, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FUNC_FUNC(bor, byte, char) + +/************************************************************************* + * Bitwise XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) ^ (b)) +/* C integer */ +FUNC_FUNC(bxor, int8_t, int8_t) +FUNC_FUNC(bxor, uint8_t, uint8_t) +FUNC_FUNC(bxor, int16_t, int16_t) +FUNC_FUNC(bxor, uint16_t, uint16_t) +FUNC_FUNC(bxor, int32_t, int32_t) +FUNC_FUNC(bxor, uint32_t, uint32_t) +FUNC_FUNC(bxor, int64_t, int64_t) +FUNC_FUNC(bxor, uint64_t, uint64_t) +FUNC_FUNC(bxor, long, long) +FUNC_FUNC(bxor, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(bxor, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(bxor, fortran_integer1, 
ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(bxor, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(bxor, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(bxor, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(bxor, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FUNC_FUNC(bxor, byte, char) + +/************************************************************************* + * Max location + *************************************************************************/ + +LOC_FUNC(maxloc, float_int, >) +LOC_FUNC(maxloc, double_int, >) +LOC_FUNC(maxloc, long_int, >) +LOC_FUNC(maxloc, 2int, >) +LOC_FUNC(maxloc, short_int, >) +LOC_FUNC(maxloc, long_double_int, >) + +/************************************************************************* + * Min location + *************************************************************************/ + +LOC_FUNC(minloc, float_int, <) +LOC_FUNC(minloc, double_int, <) +LOC_FUNC(minloc, long_int, <) +LOC_FUNC(minloc, 2int, <) +LOC_FUNC(minloc, short_int, <) +LOC_FUNC(minloc, long_double_int, <) + + + +/* + * This is a three buffer (2 input and 1 output) version of the reduction + * routines, needed for some optimizations. 
+ */ +#define FUNC_3BUF(name, type_name, type) \ + static \ + void ompi_op_rocm_3buff_##name##_##type_name(const void *in1, const void *in2, void *out, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + int threads_per_block, max_blocks; \ + int source1_device, source2_device, target_device; \ + type *source1, *source2, *target; \ + int n = *count; \ + device_op_pre(in1, (void**)&source1, &source1_device, \ + in2, (void**)&source2, &source2_device, \ + out, (void**)&target, &target_device, \ + n, *dtype, \ + &threads_per_block, &max_blocks, &device, stream); \ + hipStream_t *custream = (hipStream_t*)stream->stream; \ + ompi_op_rocm_3buff_##name##_##type_name##_submit(source1, source2, target, n, threads_per_block, max_blocks, *custream);\ + device_op_post(source1, source1_device, source2, source2_device, out, target, target_device, n, *dtype, device, stream);\ + } + + +#define OP_FUNC_3BUF(name, type_name, type, op, ...) FUNC_3BUF(name, __VA_ARGS__##type_name, __VA_ARGS__##type) + +/* reuse the macro above, no work is actually done so we don't care about the func */ +#define FUNC_FUNC_3BUF(name, type_name, type, ...) FUNC_3BUF(name, __VA_ARGS__##type_name, __VA_ARGS__##type) + +/* + * Since all the functions in this file are essentially identical, we + * use a macro to substitute in names and types. The core operation + * in all functions that use this macro is the same. 
+ * + * This macro is for minloc and maxloc + */ +#define LOC_FUNC_3BUF(name, type_name, op) FUNC_3BUF(name, type_name, ompi_op_predefined_##type_name##_t) + + +/* Dispatch Fortran types to C types */ +#define FORT_INT_FUNC_3BUF(name, type_name, type) \ + static \ + void ompi_op_rocm_3buff_##name##_##type_name(const void *in1, const void *in2, void *out, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + \ + _Static_assert(sizeof(type) >= sizeof(int8_t) && sizeof(type) <= sizeof(int64_t)); \ + switch(sizeof(type)) { \ + case sizeof(int8_t): \ + ompi_op_rocm_3buff_##name##_int8_t(in1, in2, out, count, dtype, device, stream, module); \ + break; \ + case sizeof(int16_t): \ + ompi_op_rocm_3buff_##name##_int16_t(in1, in2, out, count, dtype, device, stream, module); \ + break; \ + case sizeof(int32_t): \ + ompi_op_rocm_3buff_##name##_int32_t(in1, in2, out, count, dtype, device, stream, module); \ + break; \ + case sizeof(int64_t): \ + ompi_op_rocm_3buff_##name##_int64_t(in1, in2, out, count, dtype, device, stream, module); \ + break; \ + } \ + } + +/* Dispatch Fortran types to C types */ +#define FORT_FLOAT_FUNC_3BUF(name, type_name, type) \ + static \ + void ompi_op_rocm_3buff_##name##_##type_name(const void *in1, const void *in2, void *out, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + _Static_assert(sizeof(type) >= sizeof(float) && sizeof(type) <= sizeof(long double)); \ + switch(sizeof(type)) { \ + case sizeof(float): \ + ompi_op_rocm_3buff_##name##_float(in1, in2, out, count, dtype, device, stream, module); \ + break; \ + case sizeof(double): \ + ompi_op_rocm_3buff_##name##_double(in1, in2, out, count, dtype, device, stream, module); \ + break; \ + case sizeof(long double): \ + ompi_op_rocm_3buff_##name##_long_double(in1, in2, out, count, dtype, device, 
stream, module); \ + break; \ + } \ + } + + +/************************************************************************* + * Max + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) > (b) ? (a) : (b)) +/* C integer */ +FUNC_FUNC_3BUF(max, int8_t, int8_t) +FUNC_FUNC_3BUF(max, uint8_t, uint8_t) +FUNC_FUNC_3BUF(max, int16_t, int16_t) +FUNC_FUNC_3BUF(max, uint16_t, uint16_t) +FUNC_FUNC_3BUF(max, int32_t, int32_t) +FUNC_FUNC_3BUF(max, uint32_t, uint32_t) +FUNC_FUNC_3BUF(max, int64_t, int64_t) +FUNC_FUNC_3BUF(max, uint64_t, uint64_t) +FUNC_FUNC_3BUF(max, long, long) +FUNC_FUNC_3BUF(max, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(max, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(max, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(max, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(max, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(max, fortran_integer8, ompi_fortran_integer8_t) +#endif +/* Floating point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF(max, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF(max, short_float, opal_short_float_t) +#endif +#endif // 0 +FUNC_FUNC_3BUF(max, float, float) +FUNC_FUNC_3BUF(max, double, double) +FUNC_FUNC_3BUF(max, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC_3BUF(max, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC_3BUF(max, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC_3BUF(max, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC_3BUF(max, fortran_real4, 
ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC_3BUF(max, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC_3BUF(max, fortran_real16, ompi_fortran_real16_t) +#endif + + +/************************************************************************* + * Min + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) < (b) ? (a) : (b)) +/* C integer */ +FUNC_FUNC_3BUF(min, int8_t, int8_t) +FUNC_FUNC_3BUF(min, uint8_t, uint8_t) +FUNC_FUNC_3BUF(min, int16_t, int16_t) +FUNC_FUNC_3BUF(min, uint16_t, uint16_t) +FUNC_FUNC_3BUF(min, int32_t, int32_t) +FUNC_FUNC_3BUF(min, uint32_t, uint32_t) +FUNC_FUNC_3BUF(min, int64_t, int64_t) +FUNC_FUNC_3BUF(min, uint64_t, uint64_t) +FUNC_FUNC_3BUF(min, long, long) +FUNC_FUNC_3BUF(min, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(min, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(min, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(min, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(min, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(min, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(min, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Floating point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF(min, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF(min, short_float, opal_short_float_t) +#endif +#endif // 0 +FUNC_FUNC_3BUF(min, float, float) +FUNC_FUNC_3BUF(min, double, double) +FUNC_FUNC_3BUF(min, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC_3BUF(min, fortran_real, 
ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC_3BUF(min, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC_3BUF(min, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC_3BUF(min, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC_3BUF(min, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC_3BUF(min, fortran_real16, ompi_fortran_real16_t) +#endif + +/************************************************************************* + * Sum + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF(sum, int8_t, int8_t, +) +OP_FUNC_3BUF(sum, uint8_t, uint8_t, +) +OP_FUNC_3BUF(sum, int16_t, int16_t, +) +OP_FUNC_3BUF(sum, uint16_t, uint16_t, +) +OP_FUNC_3BUF(sum, int32_t, int32_t, +) +OP_FUNC_3BUF(sum, uint32_t, uint32_t, +) +OP_FUNC_3BUF(sum, int64_t, int64_t, +) +OP_FUNC_3BUF(sum, uint64_t, uint64_t, +) +OP_FUNC_3BUF(sum, long, long, +) +OP_FUNC_3BUF(sum, ulong, unsigned long, +) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(sum, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(sum, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(sum, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(sum, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(sum, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(sum, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Floating point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_3BUF(sum, short_float, short float, +) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) 
+OP_FUNC_3BUF(sum, short_float, opal_short_float_t, +) +#endif +#endif // 0 +OP_FUNC_3BUF(sum, float, float, +) +OP_FUNC_3BUF(sum, double, double, +) +OP_FUNC_3BUF(sum, long_double, long double, +) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC_3BUF(sum, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC_3BUF(sum, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC_3BUF(sum, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC_3BUF(sum, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC_3BUF(sum, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC_3BUF(sum, fortran_real16, ompi_fortran_real16_t) +#endif +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF(sum, c_short_float_complex, short float _Complex, +) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_3BUF(sum, c_long_double_complex, long double _Complex, +) +#endif // 0 +#undef current_func +#define current_func(a, b) (hipCaddf(a,b)) +FUNC_FUNC_3BUF(sum, c_float_complex, hipFloatComplex) +#undef current_func +#define current_func(a, b) (hipCadd(a,b)) +FUNC_FUNC_3BUF(sum, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF(prod, int8_t, int8_t, *) +OP_FUNC_3BUF(prod, uint8_t, uint8_t, *) +OP_FUNC_3BUF(prod, int16_t, int16_t, *) +OP_FUNC_3BUF(prod, uint16_t, uint16_t, *) +OP_FUNC_3BUF(prod, int32_t, int32_t, *) +OP_FUNC_3BUF(prod, uint32_t, uint32_t, *) +OP_FUNC_3BUF(prod, int64_t, int64_t, *) +OP_FUNC_3BUF(prod, uint64_t, uint64_t, *) +OP_FUNC_3BUF(prod, long, long, *) 
+OP_FUNC_3BUF(prod, ulong, unsigned long, *) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(prod, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(prod, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(prod, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(prod, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(prod, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(prod, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Floating point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +FORT_FLOAT_FUNC_3BUF(prod, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FORT_FLOAT_FUNC_3BUF(prod, short_float, opal_short_float_t) +#endif +#endif // 0 +OP_FUNC_3BUF(prod, float, float, *) +OP_FUNC_3BUF(prod, double, double, *) +OP_FUNC_3BUF(prod, long_double, long double, *) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC_3BUF(prod, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC_3BUF(prod, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC_3BUF(prod, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC_3BUF(prod, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC_3BUF(prod, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC_3BUF(prod, fortran_real16, ompi_fortran_real16_t) +#endif +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF(prod, c_short_float_complex, short float _Complex, *) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_PROD_FUNC_3BUF(c_short_float_complex, opal_short_float_t) 
+#endif +OP_FUNC_3BUF(prod, c_long_double_complex, long double _Complex, *) +#endif // 0 +#undef current_func +#define current_func(a, b) (hipCmulf(a,b)) +FUNC_FUNC_3BUF(prod, c_float_complex, hipFloatComplex) +#undef current_func +#define current_func(a, b) (hipCmul(a,b)) +FUNC_FUNC_3BUF(prod, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Logical AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) && (b)) +/* C integer */ +FUNC_FUNC_3BUF(land, int8_t, int8_t) +FUNC_FUNC_3BUF(land, uint8_t, uint8_t) +FUNC_FUNC_3BUF(land, int16_t, int16_t) +FUNC_FUNC_3BUF(land, uint16_t, uint16_t) +FUNC_FUNC_3BUF(land, int32_t, int32_t) +FUNC_FUNC_3BUF(land, uint32_t, uint32_t) +FUNC_FUNC_3BUF(land, int64_t, int64_t) +FUNC_FUNC_3BUF(land, uint64_t, uint64_t) +FUNC_FUNC_3BUF(land, long, long) +FUNC_FUNC_3BUF(land, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC_3BUF(land, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC_3BUF(land, bool, bool) + +/************************************************************************* + * Logical OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) || (b)) +/* C integer */ +FUNC_FUNC_3BUF(lor, int8_t, int8_t) +FUNC_FUNC_3BUF(lor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(lor, int16_t, int16_t) +FUNC_FUNC_3BUF(lor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(lor, int32_t, int32_t) +FUNC_FUNC_3BUF(lor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(lor, int64_t, int64_t) +FUNC_FUNC_3BUF(lor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(lor, long, long) +FUNC_FUNC_3BUF(lor, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC_3BUF(lor, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC_3BUF(lor, bool, bool) + 
+/************************************************************************* + * Logical XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a ? 1 : 0) ^ (b ? 1: 0)) +/* C integer */ +FUNC_FUNC_3BUF(lxor, int8_t, int8_t) +FUNC_FUNC_3BUF(lxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(lxor, int16_t, int16_t) +FUNC_FUNC_3BUF(lxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(lxor, int32_t, int32_t) +FUNC_FUNC_3BUF(lxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(lxor, int64_t, int64_t) +FUNC_FUNC_3BUF(lxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(lxor, long, long) +FUNC_FUNC_3BUF(lxor, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC_3BUF(lxor, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC_3BUF(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) & (b)) +/* C integer */ +FUNC_FUNC_3BUF(band, int8_t, int8_t) +FUNC_FUNC_3BUF(band, uint8_t, uint8_t) +FUNC_FUNC_3BUF(band, int16_t, int16_t) +FUNC_FUNC_3BUF(band, uint16_t, uint16_t) +FUNC_FUNC_3BUF(band, int32_t, int32_t) +FUNC_FUNC_3BUF(band, uint32_t, uint32_t) +FUNC_FUNC_3BUF(band, int64_t, int64_t) +FUNC_FUNC_3BUF(band, uint64_t, uint64_t) +FUNC_FUNC_3BUF(band, long, long) +FUNC_FUNC_3BUF(band, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(band, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(band, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(band, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(band, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 
+FORT_INT_FUNC_3BUF(band, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(band, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FORT_INT_FUNC_3BUF(band, byte, char) + +/************************************************************************* + * Bitwise OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) | (b)) +/* C integer */ +FUNC_FUNC_3BUF(bor, int8_t, int8_t) +FUNC_FUNC_3BUF(bor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(bor, int16_t, int16_t) +FUNC_FUNC_3BUF(bor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(bor, int32_t, int32_t) +FUNC_FUNC_3BUF(bor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(bor, int64_t, int64_t) +FUNC_FUNC_3BUF(bor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(bor, long, long) +FUNC_FUNC_3BUF(bor, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(bor, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(bor, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(bor, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(bor, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(bor, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(bor, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FORT_INT_FUNC_3BUF(bor, byte, char) + +/************************************************************************* + * Bitwise XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) ^ (b)) +/* C integer */ +FUNC_FUNC_3BUF(bxor, int8_t, int8_t) +FUNC_FUNC_3BUF(bxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(bxor, int16_t, int16_t) +FUNC_FUNC_3BUF(bxor, 
uint16_t, uint16_t) +FUNC_FUNC_3BUF(bxor, int32_t, int32_t) +FUNC_FUNC_3BUF(bxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(bxor, int64_t, int64_t) +FUNC_FUNC_3BUF(bxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(bxor, long, long) +FUNC_FUNC_3BUF(bxor, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(bxor, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(bxor, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(bxor, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(bxor, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(bxor, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(bxor, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FORT_INT_FUNC_3BUF(bxor, byte, char) + +/************************************************************************* + * Max location + *************************************************************************/ + +LOC_FUNC_3BUF(maxloc, float_int, >) +LOC_FUNC_3BUF(maxloc, double_int, >) +LOC_FUNC_3BUF(maxloc, long_int, >) +LOC_FUNC_3BUF(maxloc, 2int, >) +LOC_FUNC_3BUF(maxloc, short_int, >) +LOC_FUNC_3BUF(maxloc, long_double_int, >) + +/************************************************************************* + * Min location + *************************************************************************/ + +LOC_FUNC_3BUF(minloc, float_int, <) +LOC_FUNC_3BUF(minloc, double_int, <) +LOC_FUNC_3BUF(minloc, long_int, <) +LOC_FUNC_3BUF(minloc, 2int, <) +LOC_FUNC_3BUF(minloc, short_int, <) +LOC_FUNC_3BUF(minloc, long_double_int, <) + +/* + * Helpful defines, because there's soooo many names! + * + * **NOTE** These #define's used to be strictly ordered but the use of + * designated initializers removed this restrictions. 
When adding new + * operators ALWAYS use a designated initializer! + */ + +/** C integer ***********************************************************/ +#define C_INTEGER(name, ftype) \ + [OMPI_OP_BASE_TYPE_INT8_T] = ompi_op_rocm_##ftype##_##name##_int8_t, \ + [OMPI_OP_BASE_TYPE_UINT8_T] = ompi_op_rocm_##ftype##_##name##_uint8_t, \ + [OMPI_OP_BASE_TYPE_INT16_T] = ompi_op_rocm_##ftype##_##name##_int16_t, \ + [OMPI_OP_BASE_TYPE_UINT16_T] = ompi_op_rocm_##ftype##_##name##_uint16_t, \ + [OMPI_OP_BASE_TYPE_INT32_T] = ompi_op_rocm_##ftype##_##name##_int32_t, \ + [OMPI_OP_BASE_TYPE_UINT32_T] = ompi_op_rocm_##ftype##_##name##_uint32_t, \ + [OMPI_OP_BASE_TYPE_INT64_T] = ompi_op_rocm_##ftype##_##name##_int64_t, \ + [OMPI_OP_BASE_TYPE_LONG] = ompi_op_rocm_##ftype##_##name##_long, \ + [OMPI_OP_BASE_TYPE_UNSIGNED_LONG] = ompi_op_rocm_##ftype##_##name##_ulong, \ + [OMPI_OP_BASE_TYPE_UINT64_T] = ompi_op_rocm_##ftype##_##name##_uint64_t + +/** All the Fortran integers ********************************************/ + +#if OMPI_HAVE_FORTRAN_INTEGER +#define FORTRAN_INTEGER_PLAIN(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_integer +#else +#define FORTRAN_INTEGER_PLAIN(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +#define FORTRAN_INTEGER1(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_integer1 +#else +#define FORTRAN_INTEGER1(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +#define FORTRAN_INTEGER2(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_integer2 +#else +#define FORTRAN_INTEGER2(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +#define FORTRAN_INTEGER4(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_integer4 +#else +#define FORTRAN_INTEGER4(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +#define FORTRAN_INTEGER8(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_integer8 +#else +#define FORTRAN_INTEGER8(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +#define FORTRAN_INTEGER16(name, ftype) 
ompi_op_rocm_##ftype##_##name##_fortran_integer16 +#else +#define FORTRAN_INTEGER16(name, ftype) NULL +#endif + +#define FORTRAN_INTEGER(name, ftype) \ + [OMPI_OP_BASE_TYPE_INTEGER] = FORTRAN_INTEGER_PLAIN(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER1] = FORTRAN_INTEGER1(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER2] = FORTRAN_INTEGER2(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER4] = FORTRAN_INTEGER4(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER8] = FORTRAN_INTEGER8(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER16] = FORTRAN_INTEGER16(name, ftype) + +/** All the Fortran reals ***********************************************/ + +#if OMPI_HAVE_FORTRAN_REAL +#define FLOATING_POINT_FORTRAN_REAL_PLAIN(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_real +#else +#define FLOATING_POINT_FORTRAN_REAL_PLAIN(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +#define FLOATING_POINT_FORTRAN_REAL2(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_real2 +#else +#define FLOATING_POINT_FORTRAN_REAL2(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +#define FLOATING_POINT_FORTRAN_REAL4(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_real4 +#else +#define FLOATING_POINT_FORTRAN_REAL4(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +#define FLOATING_POINT_FORTRAN_REAL8(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_real8 +#else +#define FLOATING_POINT_FORTRAN_REAL8(name, ftype) NULL +#endif +/* If: + - we have fortran REAL*16, *and* + - fortran REAL*16 matches the bit representation of the + corresponding C type + Only then do we put in function pointers for REAL*16 reductions. + Otherwise, just put in NULL. 
*/ +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +#define FLOATING_POINT_FORTRAN_REAL16(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_real16 +#else +#define FLOATING_POINT_FORTRAN_REAL16(name, ftype) NULL +#endif + +#define FLOATING_POINT_FORTRAN_REAL(name, ftype) \ + [OMPI_OP_BASE_TYPE_REAL] = FLOATING_POINT_FORTRAN_REAL_PLAIN(name, ftype), \ + [OMPI_OP_BASE_TYPE_REAL2] = FLOATING_POINT_FORTRAN_REAL2(name, ftype), \ + [OMPI_OP_BASE_TYPE_REAL4] = FLOATING_POINT_FORTRAN_REAL4(name, ftype), \ + [OMPI_OP_BASE_TYPE_REAL8] = FLOATING_POINT_FORTRAN_REAL8(name, ftype), \ + [OMPI_OP_BASE_TYPE_REAL16] = FLOATING_POINT_FORTRAN_REAL16(name, ftype) + +/** Fortran double precision ********************************************/ + +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +#define FLOATING_POINT_FORTRAN_DOUBLE_PRECISION(name, ftype) \ + ompi_op_rocm_##ftype##_##name##_fortran_double_precision +#else +#define FLOATING_POINT_FORTRAN_DOUBLE_PRECISION(name, ftype) NULL +#endif + +/** Floating point, including all the Fortran reals *********************/ + +//#if defined(HAVE_SHORT_FLOAT) || defined(HAVE_OPAL_SHORT_FLOAT_T) +//#define SHORT_FLOAT(name, ftype) ompi_op_rocm_##ftype##_##name##_short_float +//#else +#define SHORT_FLOAT(name, ftype) NULL +//#endif +#define FLOAT(name, ftype) ompi_op_rocm_##ftype##_##name##_float +#define DOUBLE(name, ftype) ompi_op_rocm_##ftype##_##name##_double +#define LONG_DOUBLE(name, ftype) ompi_op_rocm_##ftype##_##name##_long_double + +#define FLOATING_POINT(name, ftype) \ + [OMPI_OP_BASE_TYPE_SHORT_FLOAT] = SHORT_FLOAT(name, ftype), \ + [OMPI_OP_BASE_TYPE_FLOAT] = FLOAT(name, ftype), \ + [OMPI_OP_BASE_TYPE_DOUBLE] = DOUBLE(name, ftype), \ + FLOATING_POINT_FORTRAN_REAL(name, ftype), \ + [OMPI_OP_BASE_TYPE_DOUBLE_PRECISION] = FLOATING_POINT_FORTRAN_DOUBLE_PRECISION(name, ftype), \ + [OMPI_OP_BASE_TYPE_LONG_DOUBLE] = LONG_DOUBLE(name, ftype) + +/** Fortran logical *****************************************************/ + +#if 
OMPI_HAVE_FORTRAN_LOGICAL +#define FORTRAN_LOGICAL(name, ftype) \ + ompi_op_rocm_##ftype##_##name##_fortran_logical /* OMPI_OP_ROCM_TYPE_LOGICAL */ +#else +#define FORTRAN_LOGICAL(name, ftype) NULL +#endif + +#define LOGICAL(name, ftype) \ + [OMPI_OP_BASE_TYPE_LOGICAL] = FORTRAN_LOGICAL(name, ftype), \ + [OMPI_OP_BASE_TYPE_BOOL] = ompi_op_rocm_##ftype##_##name##_bool + +/** Complex *****************************************************/ +#if 0 + +#if defined(HAVE_SHORT_FLOAT__COMPLEX) || defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +#define SHORT_FLOAT_COMPLEX(name, ftype) ompi_op_rocm_##ftype##_##name##_c_short_float_complex +#else +#define SHORT_FLOAT_COMPLEX(name, ftype) NULL +#endif +#define LONG_DOUBLE_COMPLEX(name, ftype) ompi_op_rocm_##ftype##_##name##_c_long_double_complex +#else +#define SHORT_FLOAT_COMPLEX(name, ftype) NULL +#define LONG_DOUBLE_COMPLEX(name, ftype) NULL +#endif // 0 +#define FLOAT_COMPLEX(name, ftype) ompi_op_rocm_##ftype##_##name##_c_float_complex +#define DOUBLE_COMPLEX(name, ftype) ompi_op_rocm_##ftype##_##name##_c_double_complex + +#define COMPLEX(name, ftype) \ + [OMPI_OP_BASE_TYPE_C_SHORT_FLOAT_COMPLEX] = SHORT_FLOAT_COMPLEX(name, ftype), \ + [OMPI_OP_BASE_TYPE_C_FLOAT_COMPLEX] = FLOAT_COMPLEX(name, ftype), \ + [OMPI_OP_BASE_TYPE_C_DOUBLE_COMPLEX] = DOUBLE_COMPLEX(name, ftype), \ + [OMPI_OP_BASE_TYPE_C_LONG_DOUBLE_COMPLEX] = LONG_DOUBLE_COMPLEX(name, ftype) + +/** Byte ****************************************************************/ + +#define BYTE(name, ftype) \ + [OMPI_OP_BASE_TYPE_BYTE] = ompi_op_rocm_##ftype##_##name##_byte + +/** Fortran complex *****************************************************/ +/** Fortran "2" types ***************************************************/ + +#if OMPI_HAVE_FORTRAN_REAL && OMPI_SIZEOF_FLOAT == OMPI_SIZEOF_FORTRAN_REAL +#define TWOLOC_FORTRAN_2REAL(name, ftype) ompi_op_rocm_##ftype##_##name##_2double_precision +#else +#define TWOLOC_FORTRAN_2REAL(name, ftype) NULL +#endif + +#if 
OMPI_HAVE_FORTRAN_DOUBLE_PRECISION && OMPI_SIZEOF_DOUBLE == OMPI_SIZEOF_FORTRAN_DOUBLE_PRECISION +#define TWOLOC_FORTRAN_2DOUBLE_PRECISION(name, ftype) ompi_op_rocm_##ftype##_##name##_2double_precision +#else +#define TWOLOC_FORTRAN_2DOUBLE_PRECISION(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER && OMPI_SIZEOF_INT == OMPI_SIZEOF_FORTRAN_INTEGER +#define TWOLOC_FORTRAN_2INTEGER(name, ftype) ompi_op_rocm_##ftype##_##name##_2int +#else +#define TWOLOC_FORTRAN_2INTEGER(name, ftype) NULL +#endif + +/** All "2" types *******************************************************/ + +#define TWOLOC(name, ftype) \ + [OMPI_OP_BASE_TYPE_2REAL] = TWOLOC_FORTRAN_2REAL(name, ftype), \ + [OMPI_OP_BASE_TYPE_2DOUBLE_PRECISION] = TWOLOC_FORTRAN_2DOUBLE_PRECISION(name, ftype), \ + [OMPI_OP_BASE_TYPE_2INTEGER] = TWOLOC_FORTRAN_2INTEGER(name, ftype), \ + [OMPI_OP_BASE_TYPE_FLOAT_INT] = ompi_op_rocm_##ftype##_##name##_float_int, \ + [OMPI_OP_BASE_TYPE_DOUBLE_INT] = ompi_op_rocm_##ftype##_##name##_double_int, \ + [OMPI_OP_BASE_TYPE_LONG_INT] = ompi_op_rocm_##ftype##_##name##_long_int, \ + [OMPI_OP_BASE_TYPE_2INT] = ompi_op_rocm_##ftype##_##name##_2int, \ + [OMPI_OP_BASE_TYPE_SHORT_INT] = ompi_op_rocm_##ftype##_##name##_short_int, \ + [OMPI_OP_BASE_TYPE_LONG_DOUBLE_INT] = ompi_op_rocm_##ftype##_##name##_long_double_int + +/* + * MPI_OP_NULL + * All types + */ +#define FLAGS_NO_FLOAT \ + (OMPI_OP_FLAGS_INTRINSIC | OMPI_OP_FLAGS_ASSOC | OMPI_OP_FLAGS_COMMUTE) +#define FLAGS \ + (OMPI_OP_FLAGS_INTRINSIC | OMPI_OP_FLAGS_ASSOC | \ + OMPI_OP_FLAGS_FLOAT_ASSOC | OMPI_OP_FLAGS_COMMUTE) + +ompi_op_base_stream_handler_fn_t ompi_op_rocm_functions[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX] = + { + /* Corresponds to MPI_OP_NULL */ + [OMPI_OP_BASE_FORTRAN_NULL] = { + /* Leaving this empty puts in NULL for all entries */ + NULL, + }, + /* Corresponds to MPI_MAX */ + [OMPI_OP_BASE_FORTRAN_MAX] = { + C_INTEGER(max, 2buff), + FORTRAN_INTEGER(max, 2buff), + FLOATING_POINT(max, 2buff), + }, + 
/* Corresponds to MPI_MIN */ + [OMPI_OP_BASE_FORTRAN_MIN] = { + C_INTEGER(min, 2buff), + FORTRAN_INTEGER(min, 2buff), + FLOATING_POINT(min, 2buff), + }, + /* Corresponds to MPI_SUM */ + [OMPI_OP_BASE_FORTRAN_SUM] = { + C_INTEGER(sum, 2buff), + FORTRAN_INTEGER(sum, 2buff), + FLOATING_POINT(sum, 2buff), + COMPLEX(sum, 2buff), + }, + /* Corresponds to MPI_PROD */ + [OMPI_OP_BASE_FORTRAN_PROD] = { + C_INTEGER(prod, 2buff), + FORTRAN_INTEGER(prod, 2buff), + FLOATING_POINT(prod, 2buff), + COMPLEX(prod, 2buff), + }, + /* Corresponds to MPI_LAND */ + [OMPI_OP_BASE_FORTRAN_LAND] = { + C_INTEGER(land, 2buff), + LOGICAL(land, 2buff), + }, + /* Corresponds to MPI_BAND */ + [OMPI_OP_BASE_FORTRAN_BAND] = { + C_INTEGER(band, 2buff), + FORTRAN_INTEGER(band, 2buff), + BYTE(band, 2buff), + }, + /* Corresponds to MPI_LOR */ + [OMPI_OP_BASE_FORTRAN_LOR] = { + C_INTEGER(lor, 2buff), + LOGICAL(lor, 2buff), + }, + /* Corresponds to MPI_BOR */ + [OMPI_OP_BASE_FORTRAN_BOR] = { + C_INTEGER(bor, 2buff), + FORTRAN_INTEGER(bor, 2buff), + BYTE(bor, 2buff), + }, + /* Corresponds to MPI_LXOR */ + [OMPI_OP_BASE_FORTRAN_LXOR] = { + C_INTEGER(lxor, 2buff), + LOGICAL(lxor, 2buff), + }, + /* Corresponds to MPI_BXOR */ + [OMPI_OP_BASE_FORTRAN_BXOR] = { + C_INTEGER(bxor, 2buff), + FORTRAN_INTEGER(bxor, 2buff), + BYTE(bxor, 2buff), + }, + /* Corresponds to MPI_MAXLOC */ + [OMPI_OP_BASE_FORTRAN_MAXLOC] = { + TWOLOC(maxloc, 2buff), + }, + /* Corresponds to MPI_MINLOC */ + [OMPI_OP_BASE_FORTRAN_MINLOC] = { + TWOLOC(minloc, 2buff), + }, + /* Corresponds to MPI_REPLACE */ + [OMPI_OP_BASE_FORTRAN_REPLACE] = { + /* (MPI_ACCUMULATE is handled differently than the other + reductions, so just zero out its function + implementations here to ensure that users don't invoke + MPI_REPLACE with any reduction operations other than + ACCUMULATE) */ + NULL, + }, + + }; + +ompi_op_base_3buff_stream_handler_fn_t ompi_op_rocm_3buff_functions[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX] = + { + /* Corresponds to 
MPI_OP_NULL */ + [OMPI_OP_BASE_FORTRAN_NULL] = { + /* Leaving this empty puts in NULL for all entries */ + NULL, + }, + /* Corresponds to MPI_MAX */ + [OMPI_OP_BASE_FORTRAN_MAX] = { + C_INTEGER(max, 3buff), + FORTRAN_INTEGER(max, 3buff), + FLOATING_POINT(max, 3buff), + }, + /* Corresponds to MPI_MIN */ + [OMPI_OP_BASE_FORTRAN_MIN] = { + C_INTEGER(min, 3buff), + FORTRAN_INTEGER(min, 3buff), + FLOATING_POINT(min, 3buff), + }, + /* Corresponds to MPI_SUM */ + [OMPI_OP_BASE_FORTRAN_SUM] = { + C_INTEGER(sum, 3buff), + FORTRAN_INTEGER(sum, 3buff), + FLOATING_POINT(sum, 3buff), + COMPLEX(sum, 3buff), + }, + /* Corresponds to MPI_PROD */ + [OMPI_OP_BASE_FORTRAN_PROD] = { + C_INTEGER(prod, 3buff), + FORTRAN_INTEGER(prod, 3buff), + FLOATING_POINT(prod, 3buff), + COMPLEX(prod, 3buff), + }, + /* Corresponds to MPI_LAND */ + [OMPI_OP_BASE_FORTRAN_LAND] ={ + C_INTEGER(land, 3buff), + LOGICAL(land, 3buff), + }, + /* Corresponds to MPI_BAND */ + [OMPI_OP_BASE_FORTRAN_BAND] = { + C_INTEGER(band, 3buff), + FORTRAN_INTEGER(band, 3buff), + BYTE(band, 3buff), + }, + /* Corresponds to MPI_LOR */ + [OMPI_OP_BASE_FORTRAN_LOR] = { + C_INTEGER(lor, 3buff), + LOGICAL(lor, 3buff), + }, + /* Corresponds to MPI_BOR */ + [OMPI_OP_BASE_FORTRAN_BOR] = { + C_INTEGER(bor, 3buff), + FORTRAN_INTEGER(bor, 3buff), + BYTE(bor, 3buff), + }, + /* Corresponds to MPI_LXOR */ + [OMPI_OP_BASE_FORTRAN_LXOR] = { + C_INTEGER(lxor, 3buff), + LOGICAL(lxor, 3buff), + }, + /* Corresponds to MPI_BXOR */ + [OMPI_OP_BASE_FORTRAN_BXOR] = { + C_INTEGER(bxor, 3buff), + FORTRAN_INTEGER(bxor, 3buff), + BYTE(bxor, 3buff), + }, + /* Corresponds to MPI_MAXLOC */ + [OMPI_OP_BASE_FORTRAN_MAXLOC] = { + TWOLOC(maxloc, 3buff), + }, + /* Corresponds to MPI_MINLOC */ + [OMPI_OP_BASE_FORTRAN_MINLOC] = { + TWOLOC(minloc, 3buff), + }, + /* Corresponds to MPI_REPLACE */ + [OMPI_OP_BASE_FORTRAN_REPLACE] = { + /* MPI_ACCUMULATE is handled differently than the other + reductions, so just zero out its function + implementations here to ensure 
that users don't invoke + MPI_REPLACE with any reduction operations other than + ACCUMULATE */ + NULL, + }, + }; diff --git a/ompi/mca/op/rocm/op_rocm_impl.cpp b/ompi/mca/op/rocm/op_rocm_impl.cpp new file mode 100644 index 00000000000..28142ce0e0b --- /dev/null +++ b/ompi/mca/op/rocm/op_rocm_impl.cpp @@ -0,0 +1,1053 @@ +/* + * Copyright (c) 2019-2023 The University of Tennessee and The University + * of Tennessee Research Foundation. All rights + * reserved. + * Copyright (c) 2020 Research Organization for Information Science + * and Technology (RIST). All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "hip/hip_runtime.h" +#include +#include + +#include + +#include "op_rocm_impl.h" + +//#define DO_NOT_USE_INTRINSICS 1 +#define USE_VECTORS 1 + +#include + +#define ISSIGNED(x) std::is_signed_v + +template +static inline __device__ constexpr T tmax(T a, T b) { + return (a > b) ? a : b; +} + +template +static inline __device__ constexpr T tmin(T a, T b) { + return (a < b) ? 
a : b; +} + +template +static inline __device__ constexpr T tsum(T a, T b) { + return a+b; +} + +template +static inline __device__ constexpr T tprod(T a, T b) { + return a*b; +} + +template +static inline __device__ T vmax(const T& a, const T& b) { + return T{tmax(a.x, b.x), tmax(a.y, b.y), tmax(a.z, b.z), tmax(a.w, b.w)}; +} + +template +static inline __device__ T vmin(const T& a, const T& b) { + return T{tmin(a.x, b.x), tmin(a.y, b.y), tmin(a.z, b.z), tmin(a.w, b.w)}; +} + +template +static inline __device__ T vsum(const T& a, const T& b) { + return T{tsum(a.x, b.x), tsum(a.y, b.y), tsum(a.z, b.z), tsum(a.w, b.w)}; +} + +template +static inline __device__ T vprod(const T& a, const T& b) { + return T{(a.x * b.x), (a.y * b.y), (a.z * b.z), (a.w * b.w)}; +} + + +/* TODO: missing support for + * - short float (conditional on whether short float is available) + * - complex + */ + +#define VECLEN 2 +#define VECTYPE(t) t##VECLEN + +#define OP_FUNC(name, type_name, type, op) \ + static __global__ void \ + ompi_op_rocm_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \ + type *__restrict__ inout, int n) { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n; i += stride) { \ + inout[i] = inout[i] op in[i]; \ + } \ + } \ + void ompi_op_rocm_2buff_##name##_##type_name##_submit(const type *in, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream) { \ + int threads = min(threads_per_block, count); \ + int blocks = min((count + threads-1) / threads, max_blocks); \ + int n = count; \ + hipStream_t s = stream; \ + hipLaunchKernelGGL(ompi_op_rocm_2buff_##name##_##type_name##_kernel, \ + dim3(blocks), dim3(threads), 0, s, \ + in, inout, n); \ + } + +#if defined(USE_VECTORS) +#define OPV_FUNC(name, type_name, type, vtype, vlen, op) \ + static __global__ void \ + ompi_op_rocm_2buff_##name##_##type_name##_kernel(const type 
*__restrict__ in, \ + type *__restrict__ inout, int n) { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n/vlen; i += stride) { \ + ((vtype*)inout)[i] = ((vtype*)inout)[i] op ((vtype*)in)[i]; \ + } \ + int remainder = n%vlen; \ + if (index == (n/vlen) && remainder != 0) { \ + while(remainder) { \ + int idx = n - remainder--; \ + inout[idx] = inout[idx] op in[idx]; \ + } \ + } \ + } \ + void ompi_op_rocm_2buff_##name##_##type_name##_submit(const type *in, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream) { \ + int vcount = (count + vlen-1)/vlen; \ + int threads = min(threads_per_block, vcount); \ + int blocks = min((vcount + threads-1) / threads, max_blocks); \ + int n = count; \ + hipStream_t s = stream; \ + hipLaunchKernelGGL(ompi_op_rocm_2buff_##name##_##type_name##_kernel, \ + dim3(blocks), dim3(threads), 0, s, \ + in, inout, n); \ + } +#else // USE_VECTORS +#define OPV_FUNC(name, type_name, type, vtype, vlen, op) OP_FUNC(name, type_name, type, op) +#endif // USE_VECTORS + + +#define FUNC_FUNC(name, type_name, type) \ + static __global__ void \ + ompi_op_rocm_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \ + type *__restrict__ inout, int n) { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n; i += stride) { \ + inout[i] = current_func(inout[i], in[i]); \ + } \ + } \ + void \ + ompi_op_rocm_2buff_##name##_##type_name##_submit(const type *in, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream) { \ + int threads = min(threads_per_block, count); \ + int blocks = min((count + threads-1) / threads, max_blocks); \ + int n = count; \ + hipStream_t s = stream; \ + hipLaunchKernelGGL(ompi_op_rocm_2buff_##name##_##type_name##_kernel, \ + dim3(blocks), dim3(threads), 0, s, \ + in, 
inout, n); \ + } + + +#if defined(USE_VECTORS) +#define VFUNC_FUNC(name, type_name, type, vtype, vlen, vfn, fn) \ + static __global__ void \ + ompi_op_rocm_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \ + type *__restrict__ inout, int n) { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n/vlen; i += stride) { \ + ((vtype*)inout)[i] = vfn(((vtype*)inout)[i], ((vtype*)in)[i]); \ + } \ + int remainder = n%vlen; \ + if (index == (n/vlen) && remainder != 0) { \ + while(remainder) { \ + int idx = n - remainder--; \ + inout[idx] = fn(inout[idx], in[idx]); \ + } \ + } \ + } \ + static void \ + ompi_op_rocm_2buff_##name##_##type_name##_submit(const type *in, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream) { \ + int vcount = (count + vlen-1)/vlen; \ + int threads = min(threads_per_block, vcount); \ + int blocks = min((vcount + threads-1) / threads, max_blocks); \ + int n = count; \ + hipStream_t s = stream; \ + ompi_op_rocm_2buff_##name##_##type_name##_kernel<<>>(in, inout, n); \ + } +#else +#define VFUNC_FUNC(name, type_name, type, vtype, vlen, vfn, fn) FUNC_FUNC_FN(name, type_name, type, fn) +#endif // defined(USE_VECTORS) + +#define FUNC_FUNC_FN(name, type_name, type, fn) \ + static __global__ void \ + ompi_op_rocm_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \ + type *__restrict__ inout, int n) { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n; i += stride) { \ + inout[i] = fn(inout[i], in[i]); \ + } \ + } \ + void \ + ompi_op_rocm_2buff_##name##_##type_name##_submit(const type *in, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream) { \ + int threads = min(count, threads_per_block); \ + int blocks = min((count + threads-1) / threads, max_blocks); \ + 
int n = count; \ + hipStream_t s = stream; \ + ompi_op_rocm_2buff_##name##_##type_name##_kernel<<>>(in, inout, n); \ + } + +/* + * Since all the functions in this file are essentially identical, we + * use a macro to substitute in names and types. The core operation + * in all functions that use this macro is the same. + * + * This macro is for minloc and maxloc + */ + +#define LOC_FUNC(name, type_name, op) \ + static __global__ void \ + ompi_op_rocm_2buff_##name##_##type_name##_kernel(const ompi_op_predefined_##type_name##_t *__restrict__ in, \ + ompi_op_predefined_##type_name##_t *__restrict__ inout, \ + int n) \ + { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n; i += stride) { \ + const ompi_op_predefined_##type_name##_t *a = &in[i]; \ + ompi_op_predefined_##type_name##_t *b = &inout[i]; \ + if (a->v op b->v) { \ + b->v = a->v; \ + b->k = a->k; \ + } else if (a->v == b->v) { \ + b->k = (b->k < a->k ? 
b->k : a->k); \ + } \ + } \ + } \ + void \ + ompi_op_rocm_2buff_##name##_##type_name##_submit(const ompi_op_predefined_##type_name##_t *a, \ + ompi_op_predefined_##type_name##_t *b, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream) { \ + int threads = min(threads_per_block, count); \ + int blocks = min((count + threads-1) / threads, max_blocks); \ + hipStream_t s = stream; \ + hipLaunchKernelGGL(ompi_op_rocm_2buff_##name##_##type_name##_kernel, \ + dim3(blocks), dim3(threads), 0, s, \ + a, b, count); \ + } + + +#define OPV_DISPATCH(name, type_name, type) \ + void ompi_op_rocm_2buff_##name##_##type_name##_submit(const type *in, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream) { \ + static_assert(sizeof(type_name) <= sizeof(unsigned long long), "Unknown size type"); \ + if constexpr(!ISSIGNED(type)) { \ + if constexpr(sizeof(type_name) == sizeof(unsigned char)) { \ + ompi_op_rocm_2buff_##name##_uchar_submit((const unsigned char*)in, (unsigned char*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(unsigned short)) { \ + ompi_op_rocm_2buff_##name##_ushort_submit((const unsigned short*)in, (unsigned short*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(unsigned int)) { \ + ompi_op_rocm_2buff_##name##_uint_submit((const unsigned int*)in, (unsigned int*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(unsigned long)) { \ + ompi_op_rocm_2buff_##name##_ulong_submit((const unsigned long*)in, (unsigned long*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(unsigned long long)) { \ + ompi_op_rocm_2buff_##name##_ulonglong_submit((const unsigned long long*)in, (unsigned long long*)inout, count, \ + threads_per_block, \ + max_blocks, 
stream); \ + } \ + } else { \ + if constexpr(sizeof(type_name) == sizeof(char)) { \ + ompi_op_rocm_2buff_##name##_char_submit((const char*)in, (char*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(short)) { \ + ompi_op_rocm_2buff_##name##_short_submit((const short*)in, (short*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(int)) { \ + ompi_op_rocm_2buff_##name##_int_submit((const int*)in, (int*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(long)) { \ + ompi_op_rocm_2buff_##name##_long_submit((const long*)in, (long*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(long long)) { \ + ompi_op_rocm_2buff_##name##_longlong_submit((const long long*)in, (long long*)inout, count,\ + threads_per_block, \ + max_blocks, stream); \ + } \ + } \ + } + +/************************************************************************* + * Max + *************************************************************************/ + +/* C integer */ +VFUNC_FUNC(max, char, char, char4, 4, vmax, max) +VFUNC_FUNC(max, uchar, unsigned char, uchar4, 4, vmax, max) +VFUNC_FUNC(max, short, short, short4, 4, vmax, max) +VFUNC_FUNC(max, ushort, unsigned short, ushort4, 4, vmax, max) +VFUNC_FUNC(max, int, int, int4, 4, vmax, max) +VFUNC_FUNC(max, uint, unsigned int, uint4, 4, vmax, max) + +#undef current_func +#define current_func(a, b) max(a, b) +FUNC_FUNC(max, long, long) +FUNC_FUNC(max, ulong, unsigned long) +FUNC_FUNC(max, longlong, long long) +FUNC_FUNC(max, ulonglong, unsigned long long) + + +/* dispatch fixed-size types */ +OPV_DISPATCH(max, int8_t, int8_t) +OPV_DISPATCH(max, uint8_t, uint8_t) +OPV_DISPATCH(max, int16_t, int16_t) +OPV_DISPATCH(max, uint16_t, uint16_t) +OPV_DISPATCH(max, int32_t, int32_t) +OPV_DISPATCH(max, uint32_t, uint32_t) 
+OPV_DISPATCH(max, int64_t, int64_t) +OPV_DISPATCH(max, uint64_t, uint64_t) + +#if !defined(DO_NOT_USE_INTRINSICS) +#undef current_func +#define current_func(a, b) fmaxf(a, b) +#endif // DO_NOT_USE_INTRINSICS +FUNC_FUNC(max, float, float) + +#if !defined(DO_NOT_USE_INTRINSICS) +#undef current_func +#define current_func(a, b) fmax(a, b) +#endif // DO_NOT_USE_INTRINSICS +FUNC_FUNC(max, double, double) + +#undef current_func +#define current_func(a, b) ((a) > (b) ? (a) : (b)) +FUNC_FUNC(max, long_double, long double) + +/************************************************************************* + * Min + *************************************************************************/ + +/* C integer */ +VFUNC_FUNC(min, char, char, char4, 4, vmin, min) +VFUNC_FUNC(min, uchar, unsigned char, uchar4, 4, vmin, min) +VFUNC_FUNC(min, short, short, short4, 4, vmin, min) +VFUNC_FUNC(min, ushort, unsigned short, ushort4, 4, vmin, min) +VFUNC_FUNC(min, int, int, int4, 4, vmin, min) +VFUNC_FUNC(min, uint, unsigned int, uint4, 4, vmin, min) + +#undef current_func +#define current_func(a, b) min(a, b) +FUNC_FUNC(min, long, long) +FUNC_FUNC(min, ulong, unsigned long) +FUNC_FUNC(min, longlong, long long) +FUNC_FUNC(min, ulonglong, unsigned long long) +OPV_DISPATCH(min, int8_t, int8_t) +OPV_DISPATCH(min, uint8_t, uint8_t) +OPV_DISPATCH(min, int16_t, int16_t) +OPV_DISPATCH(min, uint16_t, uint16_t) +OPV_DISPATCH(min, int32_t, int32_t) +OPV_DISPATCH(min, uint32_t, uint32_t) +OPV_DISPATCH(min, int64_t, int64_t) +OPV_DISPATCH(min, uint64_t, uint64_t) + +#if !defined(DO_NOT_USE_INTRINSICS) +#undef current_func +#define current_func(a, b) fminf(a, b) +#endif // DO_NOT_USE_INTRINSICS +FUNC_FUNC(min, float, float) + +#if !defined(DO_NOT_USE_INTRINSICS) +#undef current_func +#define current_func(a, b) fmin(a, b) +#endif // DO_NOT_USE_INTRINSICS +FUNC_FUNC(min, double, double) + +#undef current_func +#define current_func(a, b) ((a) < (b) ? 
(a) : (b)) +FUNC_FUNC(min, long_double, long double) + + +/************************************************************************* + * Sum + *************************************************************************/ + +/* C integer */ +VFUNC_FUNC(sum, char, char, char4, 4, vsum, tsum) +VFUNC_FUNC(sum, uchar, unsigned char, uchar4, 4, vsum, tsum) +VFUNC_FUNC(sum, short, short, short4, 4, vsum, tsum) +VFUNC_FUNC(sum, ushort, unsigned short, ushort4, 4, vsum, tsum) +VFUNC_FUNC(sum, int, int, int4, 4, vsum, tsum) +VFUNC_FUNC(sum, uint, unsigned int, uint4, 4, vsum, tsum) + +#undef current_func +#define current_func(a, b) tsum(a, b) +FUNC_FUNC(sum, long, long) +FUNC_FUNC(sum, ulong, unsigned long) +FUNC_FUNC(sum, longlong, long long) +FUNC_FUNC(sum, ulonglong, unsigned long long) + +OPV_DISPATCH(sum, int8_t, int8_t) +OPV_DISPATCH(sum, uint8_t, uint8_t) +OPV_DISPATCH(sum, int16_t, int16_t) +OPV_DISPATCH(sum, uint16_t, uint16_t) +OPV_DISPATCH(sum, int32_t, int32_t) +OPV_DISPATCH(sum, uint32_t, uint32_t) +OPV_DISPATCH(sum, int64_t, int64_t) +OPV_DISPATCH(sum, uint64_t, uint64_t) + +OPV_FUNC(sum, float, float, float4, 4, +) +OPV_FUNC(sum, double, double, double4, 4, +) +OP_FUNC(sum, long_double, long double, +) + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC(sum, c_short_float_complex, short float _Complex, +=) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC(sum, c_long_double_complex, cuLongDoubleComplex, +=) +#endif // 0 +#undef current_func +#define current_func(a, b) (hipCaddf(a,b)) +FUNC_FUNC(sum, c_float_complex, hipFloatComplex) +#undef current_func +#define current_func(a, b) (hipCadd(a,b)) +FUNC_FUNC(sum, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +#undef current_func +#define 
current_func(a, b) tprod(a, b) +FUNC_FUNC(prod, char, char) +FUNC_FUNC(prod, uchar, unsigned char) +FUNC_FUNC(prod, short, short) +FUNC_FUNC(prod, ushort, unsigned short) +FUNC_FUNC(prod, int, int) +FUNC_FUNC(prod, uint, unsigned int) +FUNC_FUNC(prod, long, long) +FUNC_FUNC(prod, ulong, unsigned long) +FUNC_FUNC(prod, longlong, long long) +FUNC_FUNC(prod, ulonglong, unsigned long long) + +OPV_DISPATCH(prod, int8_t, int8_t) +OPV_DISPATCH(prod, uint8_t, uint8_t) +OPV_DISPATCH(prod, int16_t, int16_t) +OPV_DISPATCH(prod, uint16_t, uint16_t) +OPV_DISPATCH(prod, int32_t, int32_t) +OPV_DISPATCH(prod, uint32_t, uint32_t) +OPV_DISPATCH(prod, int64_t, int64_t) +OPV_DISPATCH(prod, uint64_t, uint64_t) + + +OPV_FUNC(prod, float, float, float4, 4, *) +OPV_FUNC(prod, double, double, double4, 4, *) +OP_FUNC(prod, long_double, long double, *) + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC(sum, c_short_float_complex, short float _Complex, +=) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC(sum, c_long_double_complex, cuLongDoubleComplex, +=) +#endif // 0 +#undef current_func +#define current_func(a, b) (hipCmulf(a,b)) +FUNC_FUNC(prod, c_float_complex, hipFloatComplex) +#undef current_func +#define current_func(a, b) (hipCmul(a,b)) +FUNC_FUNC(prod, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Logical AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) && (b)) +/* C integer */ +FUNC_FUNC(land, int8_t, int8_t) +FUNC_FUNC(land, uint8_t, uint8_t) +FUNC_FUNC(land, int16_t, int16_t) +FUNC_FUNC(land, uint16_t, uint16_t) +FUNC_FUNC(land, int32_t, int32_t) +FUNC_FUNC(land, uint32_t, uint32_t) +FUNC_FUNC(land, int64_t, int64_t) +FUNC_FUNC(land, uint64_t, uint64_t) +FUNC_FUNC(land, long, long) +FUNC_FUNC(land, ulong, unsigned long) + +/* 
C++ bool */ +FUNC_FUNC(land, bool, bool) + +/************************************************************************* + * Logical OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) || (b)) +/* C integer */ +FUNC_FUNC(lor, int8_t, int8_t) +FUNC_FUNC(lor, uint8_t, uint8_t) +FUNC_FUNC(lor, int16_t, int16_t) +FUNC_FUNC(lor, uint16_t, uint16_t) +FUNC_FUNC(lor, int32_t, int32_t) +FUNC_FUNC(lor, uint32_t, uint32_t) +FUNC_FUNC(lor, int64_t, int64_t) +FUNC_FUNC(lor, uint64_t, uint64_t) +FUNC_FUNC(lor, long, long) +FUNC_FUNC(lor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC(lor, bool, bool) + +/************************************************************************* + * Logical XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a ? 1 : 0) ^ (b ? 1: 0)) +/* C integer */ +FUNC_FUNC(lxor, int8_t, int8_t) +FUNC_FUNC(lxor, uint8_t, uint8_t) +FUNC_FUNC(lxor, int16_t, int16_t) +FUNC_FUNC(lxor, uint16_t, uint16_t) +FUNC_FUNC(lxor, int32_t, int32_t) +FUNC_FUNC(lxor, uint32_t, uint32_t) +FUNC_FUNC(lxor, int64_t, int64_t) +FUNC_FUNC(lxor, uint64_t, uint64_t) +FUNC_FUNC(lxor, long, long) +FUNC_FUNC(lxor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) & (b)) +/* C integer */ +FUNC_FUNC(band, int8_t, int8_t) +FUNC_FUNC(band, uint8_t, uint8_t) +FUNC_FUNC(band, int16_t, int16_t) +FUNC_FUNC(band, uint16_t, uint16_t) +FUNC_FUNC(band, int32_t, int32_t) +FUNC_FUNC(band, uint32_t, uint32_t) +FUNC_FUNC(band, int64_t, int64_t) +FUNC_FUNC(band, uint64_t, uint64_t) +FUNC_FUNC(band, long, long) +FUNC_FUNC(band, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC(band, byte, char) + 
+/************************************************************************* + * Bitwise OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) | (b)) +/* C integer */ +FUNC_FUNC(bor, int8_t, int8_t) +FUNC_FUNC(bor, uint8_t, uint8_t) +FUNC_FUNC(bor, int16_t, int16_t) +FUNC_FUNC(bor, uint16_t, uint16_t) +FUNC_FUNC(bor, int32_t, int32_t) +FUNC_FUNC(bor, uint32_t, uint32_t) +FUNC_FUNC(bor, int64_t, int64_t) +FUNC_FUNC(bor, uint64_t, uint64_t) +FUNC_FUNC(bor, long, long) +FUNC_FUNC(bor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC(bor, byte, char) + +/************************************************************************* + * Bitwise XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) ^ (b)) +/* C integer */ +FUNC_FUNC(bxor, int8_t, int8_t) +FUNC_FUNC(bxor, uint8_t, uint8_t) +FUNC_FUNC(bxor, int16_t, int16_t) +FUNC_FUNC(bxor, uint16_t, uint16_t) +FUNC_FUNC(bxor, int32_t, int32_t) +FUNC_FUNC(bxor, uint32_t, uint32_t) +FUNC_FUNC(bxor, int64_t, int64_t) +FUNC_FUNC(bxor, uint64_t, uint64_t) +FUNC_FUNC(bxor, long, long) +FUNC_FUNC(bxor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC(bxor, byte, char) + +/************************************************************************* + * Max location + *************************************************************************/ + +LOC_FUNC(maxloc, float_int, >) +LOC_FUNC(maxloc, double_int, >) +LOC_FUNC(maxloc, long_int, >) +LOC_FUNC(maxloc, 2int, >) +LOC_FUNC(maxloc, short_int, >) +LOC_FUNC(maxloc, long_double_int, >) + +/************************************************************************* + * Min location + *************************************************************************/ + +LOC_FUNC(minloc, float_int, <) +LOC_FUNC(minloc, double_int, <) +LOC_FUNC(minloc, long_int, <) +LOC_FUNC(minloc, 2int, <) +LOC_FUNC(minloc, short_int, <) +LOC_FUNC(minloc, 
long_double_int, <) + + +/* + * This is a three buffer (2 input and 1 output) version of the reduction + * routines, needed for some optimizations. + */ +#define OP_FUNC_3BUF(name, type_name, type, op) \ + static __global__ void \ + ompi_op_rocm_3buff_##name##_##type_name##_kernel(const type *__restrict__ in1, \ + const type *__restrict__ in2, \ + type *__restrict__ out, int n) { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n; i += stride) { \ + out[i] = in1[i] op in2[i]; \ + } \ + } \ + void ompi_op_rocm_3buff_##name##_##type_name##_submit(const type *in1, const type *in2, \ + type *out, int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream) { \ + int threads = min(threads_per_block, count); \ + int blocks = min((count + threads-1) / threads, max_blocks); \ + hipLaunchKernelGGL(ompi_op_rocm_3buff_##name##_##type_name##_kernel, \ + dim3(blocks), dim3(threads), 0, stream, \ + in1, in2, out, count); \ + } + + +/* + * Since all the functions in this file are essentially identical, we + * use a macro to substitute in names and types. The core operation + * in all functions that use this macro is the same. 
+ * + * This macro is for (out = op(in1, in2)) + */ +#define FUNC_FUNC_3BUF(name, type_name, type) \ + static __global__ void \ + ompi_op_rocm_3buff_##name##_##type_name##_kernel(const type *__restrict__ in1, \ + const type *__restrict__ in2, \ + type *__restrict__ out, int n) { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n; i += stride) { \ + out[i] = current_func(in1[i], in2[i]); \ + } \ + } \ + void \ + ompi_op_rocm_3buff_##name##_##type_name##_submit(const type *in1, const type *in2, \ + type *out, int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream) { \ + int threads = min(threads_per_block, count); \ + int blocks = min((count + threads-1) / threads, max_blocks); \ + hipLaunchKernelGGL(ompi_op_rocm_3buff_##name##_##type_name##_kernel, \ + dim3(blocks), dim3(threads), 0, stream, \ + in1, in2, out, count); \ + } + +/* + * Since all the functions in this file are essentially identical, we + * use a macro to substitute in names and types. The core operation + * in all functions that use this macro is the same. 
+ * + * This macro is for minloc and maxloc + */ +/* +#define LOC_STRUCT(type_name, type1, type2) \ + typedef struct { \ + type1 v; \ + type2 k; \ + } ompi_op_predefined_##type_name##_t; +*/ + +#define LOC_FUNC_3BUF(name, type_name, op) \ + static __global__ void \ + ompi_op_rocm_3buff_##name##_##type_name##_kernel(const ompi_op_predefined_##type_name##_t *__restrict__ in1, \ + const ompi_op_predefined_##type_name##_t *__restrict__ in2, \ + ompi_op_predefined_##type_name##_t *__restrict__ out, \ + int n) \ + { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n; i += stride) { \ + const ompi_op_predefined_##type_name##_t *a1 = &in1[i]; \ + const ompi_op_predefined_##type_name##_t *a2 = &in2[i]; \ + ompi_op_predefined_##type_name##_t *b = &out[i]; \ + if (a1->v op a2->v) { \ + b->v = a1->v; \ + b->k = a1->k; \ + } else if (a1->v == a2->v) { \ + b->v = a1->v; \ + b->k = (a2->k < a1->k ? a2->k : a1->k); \ + } else { \ + b->v = a2->v; \ + b->k = a2->k; \ + } \ + } \ + } \ + void \ + ompi_op_rocm_3buff_##name##_##type_name##_submit(const ompi_op_predefined_##type_name##_t *__restrict__ in1, \ + const ompi_op_predefined_##type_name##_t *__restrict__ in2, \ + ompi_op_predefined_##type_name##_t *__restrict__ out, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream) \ + { \ + int threads = min(threads_per_block, count); \ + int blocks = min((count + threads-1) / threads, max_blocks); \ + hipLaunchKernelGGL(ompi_op_rocm_3buff_##name##_##type_name##_kernel, \ + dim3(blocks), dim3(threads), 0, stream, \ + in1, in2, out, count); \ + } + + +/************************************************************************* + * Max + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) > (b) ? 
(a) : (b)) +/* C integer */ +FUNC_FUNC_3BUF(max, int8_t, int8_t) +FUNC_FUNC_3BUF(max, uint8_t, uint8_t) +FUNC_FUNC_3BUF(max, int16_t, int16_t) +FUNC_FUNC_3BUF(max, uint16_t, uint16_t) +FUNC_FUNC_3BUF(max, int32_t, int32_t) +FUNC_FUNC_3BUF(max, uint32_t, uint32_t) +FUNC_FUNC_3BUF(max, int64_t, int64_t) +FUNC_FUNC_3BUF(max, uint64_t, uint64_t) +FUNC_FUNC_3BUF(max, long, long) +FUNC_FUNC_3BUF(max, ulong, unsigned long) + +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF(max, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF(max, short_float, opal_short_float_t) +#endif +FUNC_FUNC_3BUF(max, float, float) +FUNC_FUNC_3BUF(max, double, double) +FUNC_FUNC_3BUF(max, long_double, long double) + +/************************************************************************* + * Min + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) < (b) ? (a) : (b)) +/* C integer */ +FUNC_FUNC_3BUF(min, int8_t, int8_t) +FUNC_FUNC_3BUF(min, uint8_t, uint8_t) +FUNC_FUNC_3BUF(min, int16_t, int16_t) +FUNC_FUNC_3BUF(min, uint16_t, uint16_t) +FUNC_FUNC_3BUF(min, int32_t, int32_t) +FUNC_FUNC_3BUF(min, uint32_t, uint32_t) +FUNC_FUNC_3BUF(min, int64_t, int64_t) +FUNC_FUNC_3BUF(min, uint64_t, uint64_t) +FUNC_FUNC_3BUF(min, long, long) +FUNC_FUNC_3BUF(min, ulong, unsigned long) + +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF(min, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF(min, short_float, opal_short_float_t) +#endif +FUNC_FUNC_3BUF(min, float, float) +FUNC_FUNC_3BUF(min, double, double) +FUNC_FUNC_3BUF(min, long_double, long double) + +/************************************************************************* + * Sum + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF(sum, int8_t, int8_t, +) +OP_FUNC_3BUF(sum, uint8_t, uint8_t, +) +OP_FUNC_3BUF(sum, 
int16_t, int16_t, +) +OP_FUNC_3BUF(sum, uint16_t, uint16_t, +) +OP_FUNC_3BUF(sum, int32_t, int32_t, +) +OP_FUNC_3BUF(sum, uint32_t, uint32_t, +) +OP_FUNC_3BUF(sum, int64_t, int64_t, +) +OP_FUNC_3BUF(sum, uint64_t, uint64_t, +) +OP_FUNC_3BUF(sum, long, long, +) +OP_FUNC_3BUF(sum, ulong, unsigned long, +) + +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_3BUF(sum, short_float, short float, +) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC_3BUF(sum, short_float, opal_short_float_t, +) +#endif +OP_FUNC_3BUF(sum, float, float, +) +OP_FUNC_3BUF(sum, double, double, +) +OP_FUNC_3BUF(sum, long_double, long double, +) + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF(sum, c_short_float_complex, short float _Complex, +) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_3BUF(sum, c_long_double_complex, cuLongDoubleComplex, +) +#endif // 0 +#undef current_func +#define current_func(a, b) (hipCaddf(a,b)) +FUNC_FUNC_3BUF(sum, c_float_complex, hipFloatComplex) +#undef current_func +#define current_func(a, b) (hipCadd(a,b)) +FUNC_FUNC_3BUF(sum, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF(prod, int8_t, int8_t, *) +OP_FUNC_3BUF(prod, uint8_t, uint8_t, *) +OP_FUNC_3BUF(prod, int16_t, int16_t, *) +OP_FUNC_3BUF(prod, uint16_t, uint16_t, *) +OP_FUNC_3BUF(prod, int32_t, int32_t, *) +OP_FUNC_3BUF(prod, uint32_t, uint32_t, *) +OP_FUNC_3BUF(prod, int64_t, int64_t, *) +OP_FUNC_3BUF(prod, uint64_t, uint64_t, *) +OP_FUNC_3BUF(prod, long, long, *) +OP_FUNC_3BUF(prod, ulong, unsigned long, *) + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF(prod, c_short_float_complex, short float _Complex, *) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) 
+COMPLEX_PROD_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_3BUF(prod, c_long_double_complex, long double _Complex, *) +#endif // 0 +#undef current_func +#define current_func(a, b) (hipCmulf(a,b)) +FUNC_FUNC_3BUF(prod, c_float_complex, hipFloatComplex) +#undef current_func +#define current_func(a, b) (hipCmul(a,b)) +FUNC_FUNC_3BUF(prod, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Logical AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) && (b)) +/* C integer */ +FUNC_FUNC_3BUF(land, int8_t, int8_t) +FUNC_FUNC_3BUF(land, uint8_t, uint8_t) +FUNC_FUNC_3BUF(land, int16_t, int16_t) +FUNC_FUNC_3BUF(land, uint16_t, uint16_t) +FUNC_FUNC_3BUF(land, int32_t, int32_t) +FUNC_FUNC_3BUF(land, uint32_t, uint32_t) +FUNC_FUNC_3BUF(land, int64_t, int64_t) +FUNC_FUNC_3BUF(land, uint64_t, uint64_t) +FUNC_FUNC_3BUF(land, long, long) +FUNC_FUNC_3BUF(land, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_3BUF(land, bool, bool) + +/************************************************************************* + * Logical OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) || (b)) +/* C integer */ +FUNC_FUNC_3BUF(lor, int8_t, int8_t) +FUNC_FUNC_3BUF(lor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(lor, int16_t, int16_t) +FUNC_FUNC_3BUF(lor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(lor, int32_t, int32_t) +FUNC_FUNC_3BUF(lor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(lor, int64_t, int64_t) +FUNC_FUNC_3BUF(lor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(lor, long, long) +FUNC_FUNC_3BUF(lor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_3BUF(lor, bool, bool) + +/************************************************************************* + * Logical XOR + *************************************************************************/ + +#undef current_func 
+#define current_func(a, b) ((a ? 1 : 0) ^ (b ? 1: 0)) +/* C integer */ +FUNC_FUNC_3BUF(lxor, int8_t, int8_t) +FUNC_FUNC_3BUF(lxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(lxor, int16_t, int16_t) +FUNC_FUNC_3BUF(lxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(lxor, int32_t, int32_t) +FUNC_FUNC_3BUF(lxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(lxor, int64_t, int64_t) +FUNC_FUNC_3BUF(lxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(lxor, long, long) +FUNC_FUNC_3BUF(lxor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_3BUF(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) & (b)) +/* C integer */ +FUNC_FUNC_3BUF(band, int8_t, int8_t) +FUNC_FUNC_3BUF(band, uint8_t, uint8_t) +FUNC_FUNC_3BUF(band, int16_t, int16_t) +FUNC_FUNC_3BUF(band, uint16_t, uint16_t) +FUNC_FUNC_3BUF(band, int32_t, int32_t) +FUNC_FUNC_3BUF(band, uint32_t, uint32_t) +FUNC_FUNC_3BUF(band, int64_t, int64_t) +FUNC_FUNC_3BUF(band, uint64_t, uint64_t) +FUNC_FUNC_3BUF(band, long, long) +FUNC_FUNC_3BUF(band, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_3BUF(band, byte, char) + +/************************************************************************* + * Bitwise OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) | (b)) +/* C integer */ +FUNC_FUNC_3BUF(bor, int8_t, int8_t) +FUNC_FUNC_3BUF(bor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(bor, int16_t, int16_t) +FUNC_FUNC_3BUF(bor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(bor, int32_t, int32_t) +FUNC_FUNC_3BUF(bor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(bor, int64_t, int64_t) +FUNC_FUNC_3BUF(bor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(bor, long, long) +FUNC_FUNC_3BUF(bor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_3BUF(bor, byte, char) + +/************************************************************************* + * 
Bitwise XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) ^ (b)) +/* C integer */ +FUNC_FUNC_3BUF(bxor, int8_t, int8_t) +FUNC_FUNC_3BUF(bxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(bxor, int16_t, int16_t) +FUNC_FUNC_3BUF(bxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(bxor, int32_t, int32_t) +FUNC_FUNC_3BUF(bxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(bxor, int64_t, int64_t) +FUNC_FUNC_3BUF(bxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(bxor, long, long) +FUNC_FUNC_3BUF(bxor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_3BUF(bxor, byte, char) + +/************************************************************************* + * Max location + *************************************************************************/ + +LOC_FUNC_3BUF(maxloc, float_int, >) +LOC_FUNC_3BUF(maxloc, double_int, >) +LOC_FUNC_3BUF(maxloc, long_int, >) +LOC_FUNC_3BUF(maxloc, 2int, >) +LOC_FUNC_3BUF(maxloc, short_int, >) +LOC_FUNC_3BUF(maxloc, long_double_int, >) + +/************************************************************************* + * Min location + *************************************************************************/ + +LOC_FUNC_3BUF(minloc, float_int, <) +LOC_FUNC_3BUF(minloc, double_int, <) +LOC_FUNC_3BUF(minloc, long_int, <) +LOC_FUNC_3BUF(minloc, 2int, <) +LOC_FUNC_3BUF(minloc, short_int, <) +LOC_FUNC_3BUF(minloc, long_double_int, <) diff --git a/ompi/mca/op/rocm/op_rocm_impl.h b/ompi/mca/op/rocm/op_rocm_impl.h new file mode 100644 index 00000000000..9beec67d9ef --- /dev/null +++ b/ompi/mca/op/rocm/op_rocm_impl.h @@ -0,0 +1,906 @@ +/* + * Copyright (c) 2019-2023 The University of Tennessee and The University + * of Tennessee Research Foundation. All rights + * reserved. + * Copyright (c) 2020 Research Organization for Information Science + * and Technology (RIST). All rights reserved. 
+ * $COPYRIGHT$
+ *
+ * Additional copyrights may follow
+ *
+ * $HEADER$
+ */
+
+#include <stdio.h>
+
+#include <hip/hip_runtime.h>
+#include <hip/hip_complex.h>
+
+#ifndef BEGIN_C_DECLS
+#if defined(c_plusplus) || defined(__cplusplus)
+# define BEGIN_C_DECLS extern "C" {
+# define END_C_DECLS }
+#else
+# define BEGIN_C_DECLS /* empty */
+# define END_C_DECLS /* empty */
+#endif
+#endif
+
+BEGIN_C_DECLS
+
+#define OP_FUNC_SIG(name, type_name, type, op) \
+    void ompi_op_rocm_2buff_##name##_##type_name##_submit(const type *in, \
+                                                          type *inout, \
+                                                          int count, \
+                                                          int threads_per_block, \
+                                                          int max_blocks, \
+                                                          hipStream_t stream);
+
+#define FUNC_FUNC_SIG(name, type_name, type) \
+    void ompi_op_rocm_2buff_##name##_##type_name##_submit(const type *in, \
+                                                          type *inout, \
+                                                          int count, \
+                                                          int threads_per_block, \
+                                                          int max_blocks, \
+                                                          hipStream_t stream);
+
+/*
+ * Since all the functions in this file are essentially identical, we
+ * use a macro to substitute in names and types.  The core operation
+ * in all functions that use this macro is the same.
+ * + * This macro is for minloc and maxloc + */ +#define LOC_STRUCT(type_name, type1, type2) \ + typedef struct { \ + type1 v; \ + type2 k; \ + } ompi_op_predefined_##type_name##_t; + +#define LOC_FUNC_SIG(name, type_name, op) \ + void ompi_op_rocm_2buff_##name##_##type_name##_submit(const ompi_op_predefined_##type_name##_t *a, \ + ompi_op_predefined_##type_name##_t *b, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream); + +/************************************************************************* + * Max + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_SIG(max, int8_t, int8_t) +FUNC_FUNC_SIG(max, uint8_t, uint8_t) +FUNC_FUNC_SIG(max, int16_t, int16_t) +FUNC_FUNC_SIG(max, uint16_t, uint16_t) +FUNC_FUNC_SIG(max, int32_t, int32_t) +FUNC_FUNC_SIG(max, uint32_t, uint32_t) +FUNC_FUNC_SIG(max, int64_t, int64_t) +FUNC_FUNC_SIG(max, uint64_t, uint64_t) +FUNC_FUNC_SIG(max, long, long) +FUNC_FUNC_SIG(max, ulong, unsigned long) + +#if 0 +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_SIG(max, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_SIG(max, short_float, opal_short_float_t) +#endif +#endif // 0 + +FUNC_FUNC_SIG(max, float, float) +FUNC_FUNC_SIG(max, double, double) +FUNC_FUNC_SIG(max, long_double, long double) + +/************************************************************************* + * Min + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_SIG(min, int8_t, int8_t) +FUNC_FUNC_SIG(min, uint8_t, uint8_t) +FUNC_FUNC_SIG(min, int16_t, int16_t) +FUNC_FUNC_SIG(min, uint16_t, uint16_t) +FUNC_FUNC_SIG(min, int32_t, int32_t) +FUNC_FUNC_SIG(min, uint32_t, uint32_t) +FUNC_FUNC_SIG(min, int64_t, int64_t) +FUNC_FUNC_SIG(min, uint64_t, uint64_t) +FUNC_FUNC_SIG(min, long, long) +FUNC_FUNC_SIG(min, ulong, unsigned long) + +#if 0 +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) 
+FUNC_FUNC_SIG(min, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_SIG(min, short_float, opal_short_float_t) +#endif +#endif // 0 + +FUNC_FUNC_SIG(min, float, float) +FUNC_FUNC_SIG(min, double, double) +FUNC_FUNC_SIG(min, long_double, long double) + +/************************************************************************* + * Sum + *************************************************************************/ + +/* C integer */ +OP_FUNC_SIG(sum, int8_t, int8_t, +=) +OP_FUNC_SIG(sum, uint8_t, uint8_t, +=) +OP_FUNC_SIG(sum, int16_t, int16_t, +=) +OP_FUNC_SIG(sum, uint16_t, uint16_t, +=) +OP_FUNC_SIG(sum, int32_t, int32_t, +=) +OP_FUNC_SIG(sum, uint32_t, uint32_t, +=) +OP_FUNC_SIG(sum, int64_t, int64_t, +=) +OP_FUNC_SIG(sum, uint64_t, uint64_t, +=) +OP_FUNC_SIG(sum, long, long, +=) +OP_FUNC_SIG(sum, ulong, unsigned long, +=) + +#if 0 +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_SIG(sum, short_float, short float, +=) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC_SIG(sum, short_float, opal_short_float_t, +=) +#endif +#endif // 0 + +OP_FUNC_SIG(sum, float, float, +=) +OP_FUNC_SIG(sum, double, double, +=) +OP_FUNC_SIG(sum, long_double, long double, +=) + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_SIG(sum, c_short_float_complex, short float _Complex, +=) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC(c_short_float_complex, opal_short_float_t) +#endif +#endif // 0 +FUNC_FUNC_SIG(sum, c_float_complex, hipFloatComplex) +FUNC_FUNC_SIG(sum, c_double_complex, hipDoubleComplex) +//OP_FUNC_SIG(sum, c_float_complex, float _Complex, +=) +//OP_FUNC_SIG(sum, c_double_complex, double _Complex, +=) +//OP_FUNC_SIG(sum, c_long_double_complex, long double _Complex, +=) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +OP_FUNC_SIG(prod, int8_t, int8_t, *=) 
+OP_FUNC_SIG(prod, uint8_t, uint8_t, *=) +OP_FUNC_SIG(prod, int16_t, int16_t, *=) +OP_FUNC_SIG(prod, uint16_t, uint16_t, *=) +OP_FUNC_SIG(prod, int32_t, int32_t, *=) +OP_FUNC_SIG(prod, uint32_t, uint32_t, *=) +OP_FUNC_SIG(prod, int64_t, int64_t, *=) +OP_FUNC_SIG(prod, uint64_t, uint64_t, *=) +OP_FUNC_SIG(prod, long, long, *=) +OP_FUNC_SIG(prod, ulong, unsigned long, *=) + +#if 0 +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_SIG(prod, short_float, short float, *=) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC_SIG(prod, short_float, opal_short_float_t, *=) +#endif +#endif // 0 + +OP_FUNC_SIG(prod, float, float, *=) +OP_FUNC_SIG(prod, double, double, *=) +OP_FUNC_SIG(prod, long_double, long double, *=) + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_SIG(prod, c_short_float_complex, short float _Complex, *=) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_PROD_FUNC(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_SIG(prod, c_long_double_complex, long double _Complex, *=) +#endif // 0 +FUNC_FUNC_SIG(prod, c_float_complex, hipFloatComplex) +FUNC_FUNC_SIG(prod, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Logical AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) && (b)) +/* C integer */ +FUNC_FUNC_SIG(land, int8_t, int8_t) +FUNC_FUNC_SIG(land, uint8_t, uint8_t) +FUNC_FUNC_SIG(land, int16_t, int16_t) +FUNC_FUNC_SIG(land, uint16_t, uint16_t) +FUNC_FUNC_SIG(land, int32_t, int32_t) +FUNC_FUNC_SIG(land, uint32_t, uint32_t) +FUNC_FUNC_SIG(land, int64_t, int64_t) +FUNC_FUNC_SIG(land, uint64_t, uint64_t) +FUNC_FUNC_SIG(land, long, long) +FUNC_FUNC_SIG(land, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_SIG(land, bool, bool) + +/************************************************************************* + * Logical OR + 
*************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) || (b)) +/* C integer */ +FUNC_FUNC_SIG(lor, int8_t, int8_t) +FUNC_FUNC_SIG(lor, uint8_t, uint8_t) +FUNC_FUNC_SIG(lor, int16_t, int16_t) +FUNC_FUNC_SIG(lor, uint16_t, uint16_t) +FUNC_FUNC_SIG(lor, int32_t, int32_t) +FUNC_FUNC_SIG(lor, uint32_t, uint32_t) +FUNC_FUNC_SIG(lor, int64_t, int64_t) +FUNC_FUNC_SIG(lor, uint64_t, uint64_t) +FUNC_FUNC_SIG(lor, long, long) +FUNC_FUNC_SIG(lor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_SIG(lor, bool, bool) + +/************************************************************************* + * Logical XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a ? 1 : 0) ^ (b ? 1: 0)) +/* C integer */ +FUNC_FUNC_SIG(lxor, int8_t, int8_t) +FUNC_FUNC_SIG(lxor, uint8_t, uint8_t) +FUNC_FUNC_SIG(lxor, int16_t, int16_t) +FUNC_FUNC_SIG(lxor, uint16_t, uint16_t) +FUNC_FUNC_SIG(lxor, int32_t, int32_t) +FUNC_FUNC_SIG(lxor, uint32_t, uint32_t) +FUNC_FUNC_SIG(lxor, int64_t, int64_t) +FUNC_FUNC_SIG(lxor, uint64_t, uint64_t) +FUNC_FUNC_SIG(lxor, long, long) +FUNC_FUNC_SIG(lxor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_SIG(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) & (b)) +/* C integer */ +FUNC_FUNC_SIG(band, int8_t, int8_t) +FUNC_FUNC_SIG(band, uint8_t, uint8_t) +FUNC_FUNC_SIG(band, int16_t, int16_t) +FUNC_FUNC_SIG(band, uint16_t, uint16_t) +FUNC_FUNC_SIG(band, int32_t, int32_t) +FUNC_FUNC_SIG(band, uint32_t, uint32_t) +FUNC_FUNC_SIG(band, int64_t, int64_t) +FUNC_FUNC_SIG(band, uint64_t, uint64_t) +FUNC_FUNC_SIG(band, long, long) +FUNC_FUNC_SIG(band, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_SIG(band, byte, char) + 
+/************************************************************************* + * Bitwise OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) | (b)) +/* C integer */ +FUNC_FUNC_SIG(bor, int8_t, int8_t) +FUNC_FUNC_SIG(bor, uint8_t, uint8_t) +FUNC_FUNC_SIG(bor, int16_t, int16_t) +FUNC_FUNC_SIG(bor, uint16_t, uint16_t) +FUNC_FUNC_SIG(bor, int32_t, int32_t) +FUNC_FUNC_SIG(bor, uint32_t, uint32_t) +FUNC_FUNC_SIG(bor, int64_t, int64_t) +FUNC_FUNC_SIG(bor, uint64_t, uint64_t) +FUNC_FUNC_SIG(bor, long, long) +FUNC_FUNC_SIG(bor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_SIG(bor, byte, char) + +/************************************************************************* + * Bitwise XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) ^ (b)) +/* C integer */ +FUNC_FUNC_SIG(bxor, int8_t, int8_t) +FUNC_FUNC_SIG(bxor, uint8_t, uint8_t) +FUNC_FUNC_SIG(bxor, int16_t, int16_t) +FUNC_FUNC_SIG(bxor, uint16_t, uint16_t) +FUNC_FUNC_SIG(bxor, int32_t, int32_t) +FUNC_FUNC_SIG(bxor, uint32_t, uint32_t) +FUNC_FUNC_SIG(bxor, int64_t, int64_t) +FUNC_FUNC_SIG(bxor, uint64_t, uint64_t) +FUNC_FUNC_SIG(bxor, long, long) +FUNC_FUNC_SIG(bxor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_SIG(bxor, byte, char) + +/************************************************************************* + * Min and max location "pair" datatypes + *************************************************************************/ + +LOC_STRUCT(float_int, float, int) +LOC_STRUCT(double_int, double, int) +LOC_STRUCT(long_int, long, int) +LOC_STRUCT(2int, int, int) +LOC_STRUCT(short_int, short, int) +LOC_STRUCT(long_double_int, long double, int) +LOC_STRUCT(ulong, unsigned long, int) + +/************************************************************************* + * Max location + *************************************************************************/ + 
+LOC_FUNC_SIG(maxloc, float_int, >) +LOC_FUNC_SIG(maxloc, double_int, >) +LOC_FUNC_SIG(maxloc, long_int, >) +LOC_FUNC_SIG(maxloc, 2int, >) +LOC_FUNC_SIG(maxloc, short_int, >) +LOC_FUNC_SIG(maxloc, long_double_int, >) + +/************************************************************************* + * Min location + *************************************************************************/ + +LOC_FUNC_SIG(minloc, float_int, <) +LOC_FUNC_SIG(minloc, double_int, <) +LOC_FUNC_SIG(minloc, long_int, <) +LOC_FUNC_SIG(minloc, 2int, <) +LOC_FUNC_SIG(minloc, short_int, <) +LOC_FUNC_SIG(minloc, long_double_int, <) + + + +#define OP_FUNC_3BUF_SIG(name, type_name, type, op) \ + void ompi_op_rocm_3buff_##name##_##type_name##_submit(const type *in1, \ + const type *in2, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream); + +#define FUNC_FUNC_3BUF_SIG(name, type_name, type) \ + void ompi_op_rocm_3buff_##name##_##type_name##_submit(const type *in1, \ + const type *in2, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream); + +#define LOC_FUNC_3BUF_SIG(name, type_name, op) \ + void ompi_op_rocm_3buff_##name##_##type_name##_submit(const ompi_op_predefined_##type_name##_t *a1, \ + const ompi_op_predefined_##type_name##_t *a2, \ + ompi_op_predefined_##type_name##_t *b, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream); + + +/************************************************************************* + * Max + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(max, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(max, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(max, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(max, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(max, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(max, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(max, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(max, uint64_t, uint64_t) 
+FUNC_FUNC_3BUF_SIG(max, long, long) +FUNC_FUNC_3BUF_SIG(max, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FUNC_FUNC_3BUF_SIG(max, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FUNC_FUNC_3BUF_SIG(max, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FUNC_FUNC_3BUF_SIG(max, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FUNC_FUNC_3BUF_SIG(max, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FUNC_FUNC_3BUF_SIG(max, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FUNC_FUNC_3BUF_SIG(max, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF_SIG(max, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF_SIG(max, short_float, opal_short_float_t) +#endif +FUNC_FUNC_3BUF_SIG(max, float, float) +FUNC_FUNC_3BUF_SIG(max, double, double) +FUNC_FUNC_3BUF_SIG(max, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FUNC_FUNC_3BUF_SIG(max, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FUNC_FUNC_3BUF_SIG(max, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FUNC_FUNC_3BUF_SIG(max, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FUNC_FUNC_3BUF_SIG(max, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FUNC_FUNC_3BUF_SIG(max, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FUNC_FUNC_3BUF_SIG(max, fortran_real16, ompi_fortran_real16_t) +#endif + + +/************************************************************************* + * Min + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(min, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(min, 
uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(min, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(min, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(min, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(min, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(min, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(min, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(min, long, long) +FUNC_FUNC_3BUF_SIG(min, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FUNC_FUNC_3BUF_SIG(min, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FUNC_FUNC_3BUF_SIG(min, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FUNC_FUNC_3BUF_SIG(min, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FUNC_FUNC_3BUF_SIG(min, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FUNC_FUNC_3BUF_SIG(min, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FUNC_FUNC_3BUF_SIG(min, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF_SIG(min, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF_SIG(min, short_float, opal_short_float_t) +#endif +FUNC_FUNC_3BUF_SIG(min, float, float) +FUNC_FUNC_3BUF_SIG(min, double, double) +FUNC_FUNC_3BUF_SIG(min, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FUNC_FUNC_3BUF_SIG(min, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FUNC_FUNC_3BUF_SIG(min, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FUNC_FUNC_3BUF_SIG(min, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FUNC_FUNC_3BUF_SIG(min, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FUNC_FUNC_3BUF_SIG(min, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FUNC_FUNC_3BUF_SIG(min, 
fortran_real16, ompi_fortran_real16_t) +#endif + +/************************************************************************* + * Sum + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF_SIG(sum, int8_t, int8_t, +) +OP_FUNC_3BUF_SIG(sum, uint8_t, uint8_t, +) +OP_FUNC_3BUF_SIG(sum, int16_t, int16_t, +) +OP_FUNC_3BUF_SIG(sum, uint16_t, uint16_t, +) +OP_FUNC_3BUF_SIG(sum, int32_t, int32_t, +) +OP_FUNC_3BUF_SIG(sum, uint32_t, uint32_t, +) +OP_FUNC_3BUF_SIG(sum, int64_t, int64_t, +) +OP_FUNC_3BUF_SIG(sum, uint64_t, uint64_t, +) +OP_FUNC_3BUF_SIG(sum, long, long, +) +OP_FUNC_3BUF_SIG(sum, ulong, unsigned long, +) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +OP_FUNC_3BUF_SIG(sum, fortran_integer, ompi_fortran_integer_t, +) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +OP_FUNC_3BUF_SIG(sum, fortran_integer1, ompi_fortran_integer1_t, +) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +OP_FUNC_3BUF_SIG(sum, fortran_integer2, ompi_fortran_integer2_t, +) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +OP_FUNC_3BUF_SIG(sum, fortran_integer4, ompi_fortran_integer4_t, +) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +OP_FUNC_3BUF_SIG(sum, fortran_integer8, ompi_fortran_integer8_t, +) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +OP_FUNC_3BUF_SIG(sum, fortran_integer16, ompi_fortran_integer16_t, +) +#endif +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_3BUF_SIG(sum, short_float, short float, +) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC_3BUF_SIG(sum, short_float, opal_short_float_t, +) +#endif +OP_FUNC_3BUF_SIG(sum, float, float, +) +OP_FUNC_3BUF_SIG(sum, double, double, +) +OP_FUNC_3BUF_SIG(sum, long_double, long double, +) +#if OMPI_HAVE_FORTRAN_REAL +OP_FUNC_3BUF_SIG(sum, fortran_real, ompi_fortran_real_t, +) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +OP_FUNC_3BUF_SIG(sum, fortran_double_precision, ompi_fortran_double_precision_t, +) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +OP_FUNC_3BUF_SIG(sum, fortran_real2, 
ompi_fortran_real2_t, +) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +OP_FUNC_3BUF_SIG(sum, fortran_real4, ompi_fortran_real4_t, +) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +OP_FUNC_3BUF_SIG(sum, fortran_real8, ompi_fortran_real8_t, +) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +OP_FUNC_3BUF_SIG(sum, fortran_real16, ompi_fortran_real16_t, +) +#endif +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF_SIG(sum, c_short_float_complex, short float _Complex, +) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_3BUF_SIG(sum, c_float_complex, float _Complex, +) +OP_FUNC_3BUF_SIG(sum, c_double_complex, double _Complex, +) +OP_FUNC_3BUF_SIG(sum, c_long_double_complex, long double _Complex, +) +#endif // 0 +FUNC_FUNC_3BUF_SIG(sum, c_float_complex, hipFloatComplex) +FUNC_FUNC_3BUF_SIG(sum, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF_SIG(prod, int8_t, int8_t, *) +OP_FUNC_3BUF_SIG(prod, uint8_t, uint8_t, *) +OP_FUNC_3BUF_SIG(prod, int16_t, int16_t, *) +OP_FUNC_3BUF_SIG(prod, uint16_t, uint16_t, *) +OP_FUNC_3BUF_SIG(prod, int32_t, int32_t, *) +OP_FUNC_3BUF_SIG(prod, uint32_t, uint32_t, *) +OP_FUNC_3BUF_SIG(prod, int64_t, int64_t, *) +OP_FUNC_3BUF_SIG(prod, uint64_t, uint64_t, *) +OP_FUNC_3BUF_SIG(prod, long, long, *) +OP_FUNC_3BUF_SIG(prod, ulong, unsigned long, *) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +OP_FUNC_3BUF_SIG(prod, fortran_integer, ompi_fortran_integer_t, *) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +OP_FUNC_3BUF_SIG(prod, fortran_integer1, ompi_fortran_integer1_t, *) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +OP_FUNC_3BUF_SIG(prod, fortran_integer2, ompi_fortran_integer2_t, *) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +OP_FUNC_3BUF_SIG(prod, 
fortran_integer4, ompi_fortran_integer4_t, *) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +OP_FUNC_3BUF_SIG(prod, fortran_integer8, ompi_fortran_integer8_t, *) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +OP_FUNC_3BUF_SIG(prod, fortran_integer16, ompi_fortran_integer16_t, *) +#endif +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_3BUF_SIG(prod, short_float, short float, *) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC_3BUF_SIG(prod, short_float, opal_short_float_t, *) +#endif +OP_FUNC_3BUF_SIG(prod, float, float, *) +OP_FUNC_3BUF_SIG(prod, double, double, *) +OP_FUNC_3BUF_SIG(prod, long_double, long double, *) +#if OMPI_HAVE_FORTRAN_REAL +OP_FUNC_3BUF_SIG(prod, fortran_real, ompi_fortran_real_t, *) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +OP_FUNC_3BUF_SIG(prod, fortran_double_precision, ompi_fortran_double_precision_t, *) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +OP_FUNC_3BUF_SIG(prod, fortran_real2, ompi_fortran_real2_t, *) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +OP_FUNC_3BUF_SIG(prod, fortran_real4, ompi_fortran_real4_t, *) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +OP_FUNC_3BUF_SIG(prod, fortran_real8, ompi_fortran_real8_t, *) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +OP_FUNC_3BUF_SIG(prod, fortran_real16, ompi_fortran_real16_t, *) +#endif +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF_SIG(prod, c_short_float_complex, short float _Complex, *) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_PROD_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_3BUF_SIG(prod, c_float_complex, float _Complex, *) +OP_FUNC_3BUF_SIG(prod, c_double_complex, double _Complex, *) +OP_FUNC_3BUF_SIG(prod, c_long_double_complex, long double _Complex, *) +#endif // 0 +FUNC_FUNC_3BUF_SIG(prod, c_float_complex, hipFloatComplex) +FUNC_FUNC_3BUF_SIG(prod, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Logical AND + 
*************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(land, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(land, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(land, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(land, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(land, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(land, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(land, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(land, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(land, long, long) +FUNC_FUNC_3BUF_SIG(land, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FUNC_FUNC_3BUF_SIG(land, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC_3BUF_SIG(land, bool, bool) + +/************************************************************************* + * Logical OR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(lor, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(lor, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(lor, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(lor, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(lor, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(lor, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(lor, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(lor, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(lor, long, long) +FUNC_FUNC_3BUF_SIG(lor, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FUNC_FUNC_3BUF_SIG(lor, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC_3BUF_SIG(lor, bool, bool) + +/************************************************************************* + * Logical XOR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(lxor, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(lxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(lxor, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(lxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(lxor, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(lxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(lxor, int64_t, 
int64_t) +FUNC_FUNC_3BUF_SIG(lxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(lxor, long, long) +FUNC_FUNC_3BUF_SIG(lxor, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FUNC_FUNC_3BUF_SIG(lxor, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC_3BUF_SIG(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(band, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(band, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(band, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(band, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(band, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(band, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(band, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(band, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(band, long, long) +FUNC_FUNC_3BUF_SIG(band, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FUNC_FUNC_3BUF_SIG(band, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FUNC_FUNC_3BUF_SIG(band, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FUNC_FUNC_3BUF_SIG(band, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FUNC_FUNC_3BUF_SIG(band, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FUNC_FUNC_3BUF_SIG(band, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FUNC_FUNC_3BUF_SIG(band, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FUNC_FUNC_3BUF_SIG(band, byte, char) + +/************************************************************************* + * Bitwise OR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(bor, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(bor, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(bor, int16_t, int16_t) 
+FUNC_FUNC_3BUF_SIG(bor, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(bor, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(bor, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(bor, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(bor, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(bor, long, long) +FUNC_FUNC_3BUF_SIG(bor, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FUNC_FUNC_3BUF_SIG(bor, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FUNC_FUNC_3BUF_SIG(bor, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FUNC_FUNC_3BUF_SIG(bor, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FUNC_FUNC_3BUF_SIG(bor, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FUNC_FUNC_3BUF_SIG(bor, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FUNC_FUNC_3BUF_SIG(bor, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FUNC_FUNC_3BUF_SIG(bor, byte, char) + +/************************************************************************* + * Bitwise XOR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(bxor, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(bxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(bxor, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(bxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(bxor, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(bxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(bxor, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(bxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(bxor, long, long) +FUNC_FUNC_3BUF_SIG(bxor, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FUNC_FUNC_3BUF_SIG(bxor, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FUNC_FUNC_3BUF_SIG(bxor, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FUNC_FUNC_3BUF_SIG(bxor, fortran_integer2, ompi_fortran_integer2_t) 
+#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FUNC_FUNC_3BUF_SIG(bxor, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FUNC_FUNC_3BUF_SIG(bxor, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FUNC_FUNC_3BUF_SIG(bxor, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FUNC_FUNC_3BUF_SIG(bxor, byte, char) + +/************************************************************************* + * Max location + *************************************************************************/ + +#if 0 +#if OMPI_HAVE_FORTRAN_REAL +LOC_FUNC_3BUF_SIG(maxloc, 2real, >) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +LOC_FUNC_3BUF_SIG(maxloc, 2double_precision, >) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER +LOC_FUNC_3BUF_SIG(maxloc, 2integer, >) +#endif +#endif // 0 +LOC_FUNC_3BUF_SIG(maxloc, float_int, >) +LOC_FUNC_3BUF_SIG(maxloc, double_int, >) +LOC_FUNC_3BUF_SIG(maxloc, long_int, >) +LOC_FUNC_3BUF_SIG(maxloc, 2int, >) +LOC_FUNC_3BUF_SIG(maxloc, short_int, >) +LOC_FUNC_3BUF_SIG(maxloc, long_double_int, >) + +/************************************************************************* + * Min location + *************************************************************************/ + +#if 0 +#if OMPI_HAVE_FORTRAN_REAL +LOC_FUNC_3BUF_SIG(minloc, 2real, <) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +LOC_FUNC_3BUF_SIG(minloc, 2double_precision, <) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER +LOC_FUNC_3BUF_SIG(minloc, 2integer, <) +#endif +#endif // 0 +LOC_FUNC_3BUF_SIG(minloc, float_int, <) +LOC_FUNC_3BUF_SIG(minloc, double_int, <) +LOC_FUNC_3BUF_SIG(minloc, long_int, <) +LOC_FUNC_3BUF_SIG(minloc, 2int, <) +LOC_FUNC_3BUF_SIG(minloc, short_int, <) +LOC_FUNC_3BUF_SIG(minloc, long_double_int, <) + +END_C_DECLS diff --git a/ompi/mca/pml/ob1/pml_ob1_accelerator.c b/ompi/mca/pml/ob1/pml_ob1_accelerator.c index 737560db302..5526b3b3cbd 100644 --- a/ompi/mca/pml/ob1/pml_ob1_accelerator.c +++ 
b/ompi/mca/pml/ob1/pml_ob1_accelerator.c @@ -475,7 +475,7 @@ int mca_pml_ob1_accelerator_need_buffers(void * rreq, * best thing, but this may go away if CUDA IPC is supported everywhere in the * future. */ void mca_pml_ob1_accelerator_add_ipc_support(struct mca_btl_base_module_t* btl, int32_t flags, - ompi_proc_t* errproc, char* btlinfo) + ompi_proc_t* errproc, char* btlinfo) { mca_bml_base_endpoint_t* ep; int btl_verbose_stream = 0; @@ -502,6 +502,7 @@ void mca_pml_ob1_accelerator_add_ipc_support(struct mca_btl_base_module_t* btl, free(errhost); } ep->btl_send.bml_btls[i].btl_flags |= MCA_BTL_FLAGS_ACCELERATOR_GET; + break; } } } diff --git a/ompi/mca/pml/ob1/pml_ob1_recvfrag.c b/ompi/mca/pml/ob1/pml_ob1_recvfrag.c index a3db1458938..695da3e81ee 100644 --- a/ompi/mca/pml/ob1/pml_ob1_recvfrag.c +++ b/ompi/mca/pml/ob1/pml_ob1_recvfrag.c @@ -1266,7 +1266,7 @@ void mca_pml_ob1_recv_frag_callback_cid (mca_btl_base_module_t* btl, ob1_hdr_ntoh (hdr, hdr->hdr_common.hdr_type); - /* NTH: this should be ok as as all BTLs create a dummy segment */ + /* NTH: this should be ok as all BTLs create a dummy segment */ segments->seg_len -= offsetof (mca_pml_ob1_ext_match_hdr_t, hdr_match); segments->seg_addr.pval = (void *) hdr_match; diff --git a/ompi/mca/pml/ob1/pml_ob1_recvreq.c b/ompi/mca/pml/ob1/pml_ob1_recvreq.c index e67202868e8..8b87bf18256 100644 --- a/ompi/mca/pml/ob1/pml_ob1_recvreq.c +++ b/ompi/mca/pml/ob1/pml_ob1_recvreq.c @@ -393,7 +393,7 @@ static int mca_pml_ob1_recv_request_get_frag_failed (mca_pml_ob1_rdma_frag_t *fr /* tell peer to fall back on send for this region */ rc = mca_pml_ob1_recv_request_ack_send(NULL, proc, frag->rdma_hdr.hdr_rget.hdr_rndv.hdr_src_req.lval, - recvreq, frag->rdma_offset, frag->rdma_length, false); + recvreq, frag->rdma_offset, frag->rdma_length, true); MCA_PML_OB1_RDMA_FRAG_RETURN(frag); return rc; } diff --git a/ompi/op/Makefile.am b/ompi/op/Makefile.am index 5599c31311b..f0ba89c5618 100644 --- a/ompi/op/Makefile.am +++ 
b/ompi/op/Makefile.am @@ -22,6 +22,8 @@ # This makefile.am does not stand on its own - it is included from # ompi/Makefile.am +dist_ompidata_DATA += op/help-ompi-op.txt + headers += op/op.h lib@OMPI_LIBMPI_NAME@_la_SOURCES += op/op.c diff --git a/ompi/op/help-ompi-op.txt b/ompi/op/help-ompi-op.txt new file mode 100644 index 00000000000..5cfb60b8f9f --- /dev/null +++ b/ompi/op/help-ompi-op.txt @@ -0,0 +1,15 @@ +# -*- text -*- +# +# Copyright (c) 2004-2023 The University of Tennessee and The University +# of Tennessee Research Foundation. All rights +# reserved. +# $COPYRIGHT$ +# +# Additional copyrights may follow +# +# $HEADER$ +# +# This is the US/English help file for Open MPI's allocator bucket support +# +[missing implementation] +ERROR: No suitable module for op %s on type %s found for device memory! diff --git a/ompi/op/op.c b/ompi/op/op.c index 3977fa8b97b..a75d6b33d5b 100644 --- a/ompi/op/op.c +++ b/ompi/op/op.c @@ -475,6 +475,7 @@ static void ompi_op_construct(ompi_op_t *new_op) new_op->o_3buff_intrinsic.fns[i] = NULL; new_op->o_3buff_intrinsic.modules[i] = NULL; } + new_op->o_device_op = NULL; } @@ -506,4 +507,19 @@ static void ompi_op_destruct(ompi_op_t *op) op->o_3buff_intrinsic.modules[i] = NULL; } } + + if (op->o_device_op != NULL) { + for (i = 0; i < OMPI_OP_BASE_TYPE_MAX; ++i) { + if( NULL != op->o_device_op->do_intrinsic.modules[i] ) { + OBJ_RELEASE(op->o_device_op->do_intrinsic.modules[i]); + op->o_device_op->do_intrinsic.modules[i] = NULL; + } + if( NULL != op->o_device_op->do_3buff_intrinsic.modules[i] ) { + OBJ_RELEASE(op->o_device_op->do_3buff_intrinsic.modules[i]); + op->o_device_op->do_3buff_intrinsic.modules[i] = NULL; + } + } + free(op->o_device_op); + op->o_device_op = NULL; + } } diff --git a/ompi/op/op.h b/ompi/op/op.h index 3aa95be7b90..150ae5ebc0e 100644 --- a/ompi/op/op.h +++ b/ompi/op/op.h @@ -3,7 +3,7 @@ * Copyright (c) 2004-2006 The Trustees of Indiana University and Indiana * University Research and Technology * Corporation. 
All rights reserved. - * Copyright (c) 2004-2007 The University of Tennessee and The University + * Copyright (c) 2004-2023 The University of Tennessee and The University * of Tennessee Research Foundation. All rights * reserved. * Copyright (c) 2004-2007 High Performance Computing Center Stuttgart, @@ -44,6 +44,7 @@ #include "opal/class/opal_object.h" #include "opal/util/printf.h" +#include "opal/util/show_help.h" #include "ompi/datatype/ompi_datatype.h" #include "ompi/mpi/fortran/base/fint_2_int.h" @@ -122,6 +123,15 @@ enum ompi_op_type { OMPI_OP_REPLACE, OMPI_OP_NUM_OF_TYPES }; + +/* device op information */ +struct ompi_device_op_t { + opal_accelerator_stream_t *do_stream; + ompi_op_base_op_stream_fns_t do_intrinsic; + ompi_op_base_op_3buff_stream_fns_t do_3buff_intrinsic; +}; +typedef struct ompi_device_op_t ompi_device_op_t; + /** * Back-end type of MPI_Op */ @@ -167,6 +177,10 @@ struct ompi_op_t { /** 3-buffer functions, which is only for intrinsic ops. No need for the C/C++/Fortran user-defined functions. */ ompi_op_base_op_3buff_fns_t o_3buff_intrinsic; + + /** device functions, only for intrinsic ops. + Provided if device support is detected. */ + ompi_device_op_t *o_device_op; }; /** @@ -376,7 +390,7 @@ OMPI_DECLSPEC void ompi_op_set_java_callback(ompi_op_t *op, void *jnienv, * this function is provided to hide the internal structure field * names. */ -static inline bool ompi_op_is_intrinsic(ompi_op_t * op) +static inline bool ompi_op_is_intrinsic(const ompi_op_t * op) { return (bool) (0 != (op->o_flags & OMPI_OP_FLAGS_INTRINSIC)); } @@ -469,7 +483,6 @@ static inline bool ompi_op_is_valid(ompi_op_t * op, ompi_datatype_t * ddt, return true; } - /** * Perform a reduction operation. * @@ -500,9 +513,11 @@ static inline bool ompi_op_is_valid(ompi_op_t * op, ompi_datatype_t * ddt, * optimization). If you give it an intrinsic op with a datatype that * is not defined to have that operation, it is likely to seg fault. 
*/ -static inline void ompi_op_reduce(ompi_op_t * op, void *source, - void *target, size_t full_count, - ompi_datatype_t * dtype) +static inline void ompi_op_reduce_stream(ompi_op_t * op, const void *source, + void *target, size_t full_count, + ompi_datatype_t * dtype, + int device, + opal_accelerator_stream_t *stream) { MPI_Fint f_dtype, f_count; int count = full_count; @@ -531,7 +546,7 @@ static inline void ompi_op_reduce(ompi_op_t * op, void *source, } shift = done_count * ext; // Recurse one level in iterations of 'int' - ompi_op_reduce(op, (char*)source + shift, (char*)target + shift, iter_count, dtype); + ompi_op_reduce_stream(op, (char*)source + shift, (char*)target + shift, iter_count, dtype, device, stream); done_count += iter_count; } return; @@ -560,6 +575,44 @@ static inline void ompi_op_reduce(ompi_op_t * op, void *source, * :-) */ + bool use_device_op = false; + /* check if either of the buffers is on a device and if so make sure we can + * access handle it properly */ + if (device != MCA_ACCELERATOR_NO_DEVICE_ID && + ompi_datatype_is_predefined(dtype) && + 0 != (op->o_flags & OMPI_OP_FLAGS_INTRINSIC) && + NULL != op->o_device_op) { + use_device_op = true; + } + + if (!use_device_op) { + /* query the accelerator for whether we can still execute */ + int source_dev_id, target_dev_id; + uint64_t source_flags, target_flags; + int target_check_addr = opal_accelerator.check_addr(target, &target_dev_id, &target_flags); + int source_check_addr = opal_accelerator.check_addr(source, &source_dev_id, &source_flags); + if (target_check_addr > 0 && + source_check_addr > 0 && + ompi_datatype_is_predefined(dtype) && + 0 != (op->o_flags & OMPI_OP_FLAGS_INTRINSIC) && + NULL != op->o_device_op) { + use_device_op = true; + if (target_dev_id == source_dev_id) { + /* both inputs are on the same device; if not the op will take of that */ + device = target_dev_id; + } + } else { + /* check whether we can access the memory from the host */ + if ((target_check_addr == 0 || 
(target_flags & MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY)) && + (source_check_addr == 0 || (source_flags & MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY))) { + /* nothing to be done, we won't need device-capable ops */ + } else { + opal_show_help("help-ompi-op.txt", "missing implementation", true, op->o_name, dtype->name); + abort(); + } + } + } + /* For intrinsics, we also pass the corresponding op module */ if (0 != (op->o_flags & OMPI_OP_FLAGS_INTRINSIC)) { int dtype_id; @@ -569,9 +622,28 @@ static inline void ompi_op_reduce(ompi_op_t * op, void *source, } else { dtype_id = ompi_op_ddt_map[dtype->id]; } - op->o_func.intrinsic.fns[dtype_id](source, target, - &count, &dtype, - op->o_func.intrinsic.modules[dtype_id]); + if (use_device_op) { + if (NULL == op->o_device_op) { + fprintf(stderr, "no suitable device op module found!"); + abort(); // TODO: be more graceful! + } + opal_accelerator_stream_t *actual_stream = stream; + bool flush_stream = false; + if (NULL == stream) { + opal_accelerator.get_default_stream(device, &actual_stream); + flush_stream = true; + } + op->o_device_op->do_intrinsic.fns[dtype_id]((void*)source, target, + &count, &dtype, device, actual_stream, + op->o_device_op->do_intrinsic.modules[dtype_id]); + if (flush_stream) { + opal_accelerator.wait_stream(actual_stream); + } + } else { + op->o_func.intrinsic.fns[dtype_id]((void*)source, target, + &count, &dtype, + op->o_func.intrinsic.modules[dtype_id]); + } return; } @@ -579,24 +651,31 @@ static inline void ompi_op_reduce(ompi_op_t * op, void *source, if (0 != (op->o_flags & OMPI_OP_FLAGS_FORTRAN_FUNC)) { f_dtype = OMPI_INT_2_FINT(dtype->d_f_to_c_index); f_count = OMPI_INT_2_FINT(count); - op->o_func.fort_fn(source, target, &f_count, &f_dtype); + op->o_func.fort_fn((void*)source, target, &f_count, &f_dtype); return; } else if (0 != (op->o_flags & OMPI_OP_FLAGS_JAVA_FUNC)) { - op->o_func.java_data.intercept_fn(source, target, &count, &dtype, + op->o_func.java_data.intercept_fn((void*)source, target, &count, 
&dtype, op->o_func.java_data.baseType, op->o_func.java_data.jnienv, op->o_func.java_data.object); return; } - op->o_func.c_fn(source, target, &count, &dtype); + op->o_func.c_fn((void*)source, target, &count, &dtype); return; } -static inline void ompi_3buff_op_user (ompi_op_t *op, void * restrict source1, void * restrict source2, +static inline void ompi_op_reduce(ompi_op_t * op, const void *source, + void *target, size_t full_count, + ompi_datatype_t * dtype) +{ + ompi_op_reduce_stream(op, source, target, full_count, dtype, MCA_ACCELERATOR_NO_DEVICE_ID, NULL); +} + +static inline void ompi_3buff_op_user (ompi_op_t *op, const void * source1, const void * source2, void * restrict result, int count, struct ompi_datatype_t *dtype) { - ompi_datatype_copy_content_same_ddt (dtype, count, result, source1); - op->o_func.c_fn (source2, result, &count, &dtype); + ompi_datatype_copy_content_same_ddt (dtype, count, result, (void*)source1); + op->o_func.c_fn ((void*)source2, result, &count, &dtype); } /** @@ -622,24 +701,135 @@ static inline void ompi_3buff_op_user (ompi_op_t *op, void * restrict source1, v * * Otherwise, this function is the same as ompi_op_reduce. 
*/ -static inline void ompi_3buff_op_reduce(ompi_op_t * op, void *source1, - void *source2, void *target, - int count, ompi_datatype_t * dtype) +static inline void ompi_3buff_op_reduce_stream(ompi_op_t * op, const void *source1, + const void *source2, void *target, + int count, ompi_datatype_t * dtype, + int device, + opal_accelerator_stream_t *stream) { - void *restrict src1; - void *restrict src2; - void *restrict tgt; - src1 = source1; - src2 = source2; - tgt = target; + bool use_device_op = false; + if (OPAL_UNLIKELY(!ompi_op_is_intrinsic (op))) { + /* no 3buff variants for user-defined ops */ + ompi_3buff_op_user (op, source1, source2, target, count, dtype); + return; + } + + if (device != MCA_ACCELERATOR_NO_DEVICE_ID && + ompi_datatype_is_predefined(dtype) && + op->o_flags & OMPI_OP_FLAGS_INTRINSIC && + NULL != op->o_device_op) { + use_device_op = true; + } + + if (!use_device_op) { + int source1_dev_id, source2_dev_id, target_dev_id; + uint64_t source1_flags, source2_flags, target_flags; + int target_check_addr = opal_accelerator.check_addr(target, &target_dev_id, &target_flags); + int source1_check_addr = opal_accelerator.check_addr(source1, &source1_dev_id, &source1_flags); + int source2_check_addr = opal_accelerator.check_addr(source2, &source2_dev_id, &source2_flags); + /* check if either of the buffers is on a device and if so make sure we can + * access handle it properly */ + if (target_check_addr > 0 || source1_check_addr > 0 || source2_check_addr > 0) { + if (ompi_datatype_is_predefined(dtype) && + op->o_flags & OMPI_OP_FLAGS_INTRINSIC && + NULL != op->o_device_op) { + use_device_op = true; + device = target_dev_id; + } else { + /* check whether we can access the memory from the host */ + if ((target_check_addr == 0 || (target_flags & MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY)) && + (source1_check_addr == 0 || (source1_flags & MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY)) && + (source2_check_addr == 0 || (source2_flags & MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY))) 
{ + /* nothing to be done, we won't need device-capable ops */ + } else { + fprintf(stderr, "3buff op: no suitable op module found for device memory!\n"); + abort(); + } + } + } + } + + /* For intrinsics, we also pass the corresponding op module */ + if (0 != (op->o_flags & OMPI_OP_FLAGS_INTRINSIC)) { + int dtype_id; + if (!ompi_datatype_is_predefined(dtype)) { + ompi_datatype_t *dt = ompi_datatype_get_single_predefined_type_from_args(dtype); + dtype_id = ompi_op_ddt_map[dt->id]; + } else { + dtype_id = ompi_op_ddt_map[dtype->id]; + } + if (use_device_op) { + opal_accelerator_stream_t *actual_stream = stream; + bool flush_stream = false; + if (NULL == stream) { + opal_accelerator.get_default_stream(device, &actual_stream); + flush_stream = true; + } + op->o_device_op->do_3buff_intrinsic.fns[dtype_id]((void*)source1, (void*)source2, target, + &count, &dtype, device, actual_stream, + op->o_device_op->do_3buff_intrinsic.modules[dtype_id]); + if (flush_stream) { + opal_accelerator.wait_stream(actual_stream); + } + } else { + op->o_3buff_intrinsic.fns[dtype_id]((void*)source1, (void*)source2, target, + &count, &dtype, + op->o_func.intrinsic.modules[dtype_id]); + } + } +} + +static inline void ompi_3buff_op_reduce(ompi_op_t * op, const void *source1, + const void *source2, void *target, + int count, ompi_datatype_t * dtype) +{ if (OPAL_LIKELY(ompi_op_is_intrinsic (op))) { - op->o_3buff_intrinsic.fns[ompi_op_ddt_map[dtype->id]](src1, src2, - tgt, &count, - &dtype, - op->o_3buff_intrinsic.modules[ompi_op_ddt_map[dtype->id]]); + ompi_3buff_op_reduce_stream(op, source1, source2, target, count, dtype, MCA_ACCELERATOR_NO_DEVICE_ID, NULL); } else { - ompi_3buff_op_user (op, src1, src2, tgt, count, dtype); + ompi_3buff_op_user (op, source1, source2, target, count, dtype); + } +} + +static inline void ompi_op_preferred_device(ompi_op_t *op, int source_dev, + int target_dev, size_t count, + ompi_datatype_t *dtype, int *op_device) +{ + /* default to host */ + *op_device = -1; + if 
(!ompi_op_is_intrinsic (op)) { + return; + } + /* quick check: can we execute on the device? */ + int dtype_id = ompi_op_ddt_map[dtype->id]; + if (NULL == op->o_device_op || NULL == op->o_device_op->do_intrinsic.fns[dtype_id]) { + /* not available on the gpu, must select host */ + return; + } + + size_t size_type; + ompi_datatype_type_size(dtype, &size_type); + + float device_bw; + if (target_dev >= 0) { + opal_accelerator.get_mem_bw(target_dev, &device_bw); + } else if (source_dev >= 0) { + opal_accelerator.get_mem_bw(source_dev, &device_bw); + } + + // assume we reach 50% of theoretical peak on the device + device_bw /= 2.0; + + // TODO: determine at runtime (?) + const float host_bw = 10.0; // 10GB/s + + float host_startup_cost = 0.0; // host has no startup cost + float host_compute_cost = (count*size_type) / (host_bw*1024); // assume 10GB/s memory bandwidth on host + float device_startup_cost = 10.0; // 10us startup cost on device + float device_compute_cost = (count*size_type) / (device_bw*1024); + + if ((host_startup_cost + host_compute_cost) > (device_startup_cost + device_compute_cost)) { + *op_device = (target_dev >= 0) ? 
target_dev : source_dev; } } diff --git a/opal/datatype/opal_convertor.c b/opal/datatype/opal_convertor.c index 8550683a60d..24d89b0a3ba 100644 --- a/opal/datatype/opal_convertor.c +++ b/opal/datatype/opal_convertor.c @@ -50,6 +50,8 @@ static void *opal_convertor_accelerator_memcpy(void *dest, const void *src, size int res; if (!(convertor->flags & CONVERTOR_ACCELERATOR)) { return MEMCPY(dest, src, size); + } else if (convertor->flags & CONVERTOR_ACCELERATOR_UNIFIED) { + return MEMCPY(dest, src, size); } res = opal_accelerator.mem_copy(MCA_ACCELERATOR_NO_DEVICE_ID, MCA_ACCELERATOR_NO_DEVICE_ID, diff --git a/opal/datatype/opal_datatype.h b/opal/datatype/opal_datatype.h index 5f7fc53fa7d..375b0475fef 100644 --- a/opal/datatype/opal_datatype.h +++ b/opal/datatype/opal_datatype.h @@ -42,6 +42,7 @@ #include #include "opal/class/opal_object.h" +#include "opal/mca/accelerator/accelerator.h" BEGIN_C_DECLS @@ -309,6 +310,10 @@ OPAL_DECLSPEC int32_t opal_datatype_copy_content_same_ddt(const opal_datatype_t int32_t count, char *pDestBuf, char *pSrcBuf); +OPAL_DECLSPEC int32_t opal_datatype_copy_content_same_ddt_stream(const opal_datatype_t *datatype, int32_t count, + char *destination_base, char *source_base, + opal_accelerator_stream_t *stream); + OPAL_DECLSPEC int opal_datatype_compute_ptypes(opal_datatype_t *datatype); /* diff --git a/opal/datatype/opal_datatype_copy.c b/opal/datatype/opal_datatype_copy.c index e10ea97b1bb..1459b6ad558 100644 --- a/opal/datatype/opal_datatype_copy.c +++ b/opal/datatype/opal_datatype_copy.c @@ -28,6 +28,7 @@ #include #include +#include #include "opal/datatype/opal_convertor.h" #include "opal/datatype/opal_datatype.h" @@ -55,10 +56,29 @@ } \ } while (0) -static void *opal_datatype_accelerator_memcpy(void *dest, const void *src, size_t size) + +static opal_accelerator_transfer_type_t get_transfer_type(int src_dev, int dst_dev) +{ + if (src_dev == MCA_ACCELERATOR_NO_DEVICE_ID) { + if (dst_dev == MCA_ACCELERATOR_NO_DEVICE_ID) { + return 
MCA_ACCELERATOR_TRANSFER_HTOH; + } else { + return MCA_ACCELERATOR_TRANSFER_HTOD; + } + } else { + if (dst_dev == MCA_ACCELERATOR_NO_DEVICE_ID) { + return MCA_ACCELERATOR_TRANSFER_DTOH; + } else { + return MCA_ACCELERATOR_TRANSFER_DTOD; + } + } +} + +static void *opal_datatype_accelerator_memcpy(void *dest, const void *src, size_t size, + opal_accelerator_stream_t *stream) { int res; - int dev_id; + int src_dev_id = MCA_ACCELERATOR_NO_DEVICE_ID, dst_dev_id = MCA_ACCELERATOR_NO_DEVICE_ID; uint64_t flags; /* If accelerator check addr returns an error, we can only * assume it is a host buffer. If device buffer checking fails, @@ -67,12 +87,18 @@ static void *opal_datatype_accelerator_memcpy(void *dest, const void *src, size_ * and retries are also unlikely to succeed. We identify these * buffers as host buffers as attempting a memcpy would provide * a chance to succeed. */ - if (0 >= opal_accelerator.check_addr(dest, &dev_id, &flags) && - 0 >= opal_accelerator.check_addr(src, &dev_id, &flags)) { + if (0 >= opal_accelerator.check_addr(dest, &dst_dev_id, &flags) && + 0 >= opal_accelerator.check_addr(src, &src_dev_id, &flags)) { return memcpy(dest, src, size); } - res = opal_accelerator.mem_copy(MCA_ACCELERATOR_NO_DEVICE_ID, MCA_ACCELERATOR_NO_DEVICE_ID, - dest, src, size, MCA_ACCELERATOR_TRANSFER_UNSPEC); + //printf("opal_datatype_accelerator_memcpy: dst %p dev %d src %p dev %d transer_type %d\n", dest, dst_dev_id, src, src_dev_id, get_transfer_type(src_dev_id, dst_dev_id)); + if (NULL != stream) { + res = opal_accelerator.mem_copy_async(dst_dev_id, src_dev_id, + dest, src, size, stream, get_transfer_type(src_dev_id, dst_dev_id)); + } else { + res = opal_accelerator.mem_copy(dst_dev_id, src_dev_id, + dest, src, size, get_transfer_type(src_dev_id, dst_dev_id)); + } if (OPAL_SUCCESS != res) { opal_output(0, "Error in accelerator memcpy"); abort(); @@ -80,7 +106,8 @@ static void *opal_datatype_accelerator_memcpy(void *dest, const void *src, size_ return dest; } -static 
void *opal_datatype_accelerator_memmove(void *dest, const void *src, size_t size) +static void *opal_datatype_accelerator_memmove(void *dest, const void *src, size_t size, + opal_accelerator_stream_t *stream) { int res; int dev_id; @@ -96,8 +123,13 @@ static void *opal_datatype_accelerator_memmove(void *dest, const void *src, size 0 >= opal_accelerator.check_addr(src, &dev_id, &flags)) { return memmove(dest, src, size); } - res = opal_accelerator.mem_move(MCA_ACCELERATOR_NO_DEVICE_ID, MCA_ACCELERATOR_NO_DEVICE_ID, - dest, src, size, MCA_ACCELERATOR_TRANSFER_UNSPEC); + if (NULL == stream) { + res = opal_accelerator.mem_move(MCA_ACCELERATOR_NO_DEVICE_ID, MCA_ACCELERATOR_NO_DEVICE_ID, + dest, src, size, MCA_ACCELERATOR_TRANSFER_UNSPEC); + } else { + res = opal_accelerator.mem_move_async(MCA_ACCELERATOR_NO_DEVICE_ID, MCA_ACCELERATOR_NO_DEVICE_ID, + dest, src, size, stream, MCA_ACCELERATOR_TRANSFER_UNSPEC); + } if (OPAL_SUCCESS != res) { opal_output(0, "Error in accelerator memmove"); abort(); @@ -121,11 +153,12 @@ static void *opal_datatype_accelerator_memmove(void *dest, const void *src, size #define MEM_OP opal_datatype_accelerator_memmove #include "opal_datatype_copy.h" -int32_t opal_datatype_copy_content_same_ddt(const opal_datatype_t *datatype, int32_t count, - char *destination_base, char *source_base) +int32_t opal_datatype_copy_content_same_ddt_stream(const opal_datatype_t *datatype, int32_t count, + char *destination_base, char *source_base, + opal_accelerator_stream_t *stream) { ptrdiff_t extent; - int32_t (*fct)(const opal_datatype_t *, int32_t, char *, char *); + int32_t (*fct)(const opal_datatype_t *, int32_t, char *, char *, opal_accelerator_stream_t*); DO_DEBUG(opal_output(0, "opal_datatype_copy_content_same_ddt( %p, %d, dst %p, src %p )\n", (void *) datatype, count, (void *) destination_base, @@ -157,5 +190,11 @@ int32_t opal_datatype_copy_content_same_ddt(const opal_datatype_t *datatype, int fct = overlap_accelerator_copy_content_same_ddt; } } - return 
fct(datatype, count, destination_base, source_base); + return fct(datatype, count, destination_base, source_base, stream); +} + +int32_t opal_datatype_copy_content_same_ddt(const opal_datatype_t *datatype, int32_t count, + char *destination_base, char *source_base) +{ + return opal_datatype_copy_content_same_ddt_stream(datatype, count, destination_base, source_base, NULL); } diff --git a/opal/datatype/opal_datatype_copy.h b/opal/datatype/opal_datatype_copy.h index 1e10b03ed27..dba8de0baf1 100644 --- a/opal/datatype/opal_datatype_copy.h +++ b/opal/datatype/opal_datatype_copy.h @@ -44,7 +44,7 @@ static inline void _predefined_data(const dt_elem_desc_t *ELEM, const opal_datatype_t *DATATYPE, unsigned char *SOURCE_BASE, size_t TOTAL_COUNT, size_t COUNT, unsigned char *SOURCE, unsigned char *DESTINATION, - size_t *SPACE) + size_t *SPACE, opal_accelerator_stream_t *stream) { const ddt_elem_desc_t *_elem = &((ELEM)->elem); unsigned char *_source = (SOURCE) + _elem->disp; @@ -69,7 +69,7 @@ static inline void _predefined_data(const dt_elem_desc_t *ELEM, const opal_datat DO_DEBUG(opal_output(0, "copy %s( %p, %p, %" PRIsize_t " ) => space %" PRIsize_t "\n", STRINGIFY(MEM_OP_NAME), (void *) _destination, (void *) _source, do_now_bytes, *(SPACE) -_i * do_now_bytes);); - MEM_OP(_destination, _source, do_now_bytes); + MEM_OP(_destination, _source, do_now_bytes, stream); _destination += _elem->extent; _source += _elem->extent; } @@ -79,7 +79,7 @@ static inline void _predefined_data(const dt_elem_desc_t *ELEM, const opal_datat static inline void _contiguous_loop(const dt_elem_desc_t *ELEM, const opal_datatype_t *DATATYPE, unsigned char *SOURCE_BASE, size_t TOTAL_COUNT, size_t COUNT, unsigned char *SOURCE, unsigned char *DESTINATION, - size_t *SPACE) + size_t *SPACE, opal_accelerator_stream_t *stream) { ddt_loop_desc_t *_loop = (ddt_loop_desc_t *) (ELEM); ddt_endloop_desc_t *_end_loop = (ddt_endloop_desc_t *) ((ELEM) + _loop->items); @@ -91,7 +91,7 @@ static inline void 
_contiguous_loop(const dt_elem_desc_t *ELEM, const opal_datat _copy_loops *= _end_loop->size; OPAL_DATATYPE_SAFEGUARD_POINTER(_source, _copy_loops, (SOURCE_BASE), (DATATYPE), (TOTAL_COUNT)); - MEM_OP(_destination, _source, _copy_loops); + MEM_OP(_destination, _source, _copy_loops, stream); } else { for (size_t _i = 0; _i < _copy_loops; _i++) { OPAL_DATATYPE_SAFEGUARD_POINTER(_source, _end_loop->size, (SOURCE_BASE), (DATATYPE), @@ -100,7 +100,7 @@ static inline void _contiguous_loop(const dt_elem_desc_t *ELEM, const opal_datat "copy 3. %s( %p, %p, %" PRIsize_t " ) => space %" PRIsize_t "\n", STRINGIFY(MEM_OP_NAME), (void *) _destination, (void *) _source, _end_loop->size, *(SPACE) -_i * _end_loop->size);); - MEM_OP(_destination, _source, _end_loop->size); + MEM_OP(_destination, _source, _end_loop->size, stream); _source += _loop->extent; _destination += _loop->extent; } @@ -110,7 +110,8 @@ static inline void _contiguous_loop(const dt_elem_desc_t *ELEM, const opal_datat } static inline int32_t _copy_content_same_ddt(const opal_datatype_t *datatype, int32_t count, - char *destination_base, char *source_base) + char *destination_base, char *source_base, + opal_accelerator_stream_t *stream) { dt_stack_t *pStack; /* pointer to the position on the stack */ int32_t stack_pos; /* index of the stack level */ @@ -148,13 +149,20 @@ static inline int32_t _copy_content_same_ddt(const opal_datatype_t *datatype, in DO_DEBUG(opal_output(0, "copy c1. 
%s( %p, %p, %lu ) => space %lu\n", STRINGIFY(MEM_OP_NAME), (void *) destination, (void *) source, (unsigned long) memop_chunk, (unsigned long) total_length);); - MEM_OP(destination, source, memop_chunk); + MEM_OP(destination, source, memop_chunk, stream); destination += memop_chunk; source += memop_chunk; total_length -= memop_chunk; } return 0; /* completed */ } + opal_accelerator_stream_t *actual_stream = stream; + bool flush_stream = false; + if (NULL == stream) { + /* TODO: figure out the stream */ + opal_accelerator.get_default_stream(0, &actual_stream); + flush_stream = true; + } for (pos_desc = 0; (int32_t) pos_desc < count; pos_desc++) { OPAL_DATATYPE_SAFEGUARD_POINTER(destination, datatype->size, (unsigned char *) destination_base, datatype, count); @@ -164,10 +172,13 @@ static inline int32_t _copy_content_same_ddt(const opal_datatype_t *datatype, in STRINGIFY(MEM_OP_NAME), (void *) destination, (void *) source, (unsigned long) datatype->size, (unsigned long) (iov_len_local - (pos_desc * datatype->size)));); - MEM_OP(destination, source, datatype->size); + MEM_OP(destination, source, datatype->size, actual_stream); destination += extent; source += extent; } + if (flush_stream) { + opal_accelerator.wait_stream(actual_stream); + } return 0; /* completed */ } @@ -185,11 +196,18 @@ static inline int32_t _copy_content_same_ddt(const opal_datatype_t *datatype, in UPDATE_INTERNAL_COUNTERS(description, 0, pElem, count_desc); + opal_accelerator_stream_t *actual_stream = stream; + bool flush_stream = false; + if (NULL == stream) { + /* TODO: figure out the stream */ + opal_accelerator.get_default_stream(0, &actual_stream); + flush_stream = true; + } while (1) { while (OPAL_LIKELY(pElem->elem.common.flags & OPAL_DATATYPE_FLAG_DATA)) { /* now here we have a basic datatype */ _predefined_data(pElem, datatype, (unsigned char *) source_base, count, count_desc, - source, destination, &iov_len_local); + source, destination, &iov_len_local, stream); pos_desc++; /* advance 
to the next data */ UPDATE_INTERNAL_COUNTERS(description, pos_desc, pElem, count_desc); } @@ -202,6 +220,9 @@ static inline int32_t _copy_content_same_ddt(const opal_datatype_t *datatype, in if (--(pStack->count) == 0) { /* end of loop */ if (stack_pos == 0) { assert(iov_len_local == 0); + if (flush_stream) { + opal_accelerator.wait_stream(actual_stream); + } return 0; /* completed */ } stack_pos--; @@ -229,7 +250,7 @@ static inline int32_t _copy_content_same_ddt(const opal_datatype_t *datatype, in ptrdiff_t local_disp = (ptrdiff_t) source; if (pElem->loop.common.flags & OPAL_DATATYPE_FLAG_CONTIGUOUS) { _contiguous_loop(pElem, datatype, (unsigned char *) source_base, count, count_desc, - source, destination, &iov_len_local); + source, destination, &iov_len_local, actual_stream); pos_desc += pElem->loop.items + 1; goto update_loop_description; } diff --git a/opal/mca/accelerator/accelerator.h b/opal/mca/accelerator/accelerator.h index efc951377ca..d49fd8077d8 100644 --- a/opal/mca/accelerator/accelerator.h +++ b/opal/mca/accelerator/accelerator.h @@ -129,6 +129,24 @@ struct opal_accelerator_event_t { typedef struct opal_accelerator_event_t opal_accelerator_event_t; OBJ_CLASS_DECLARATION(opal_accelerator_event_t); +struct opal_accelerator_mempool_t { + opal_object_t super; + /* Memory pool object */ + void *mempool; +}; +typedef struct opal_accelerator_event_t opal_accelerator_mempool_t; +OBJ_CLASS_DECLARATION(opal_accelerator_mempool_t); + +/** + * Query the default stream. + * + * @param[OUT] stream Set to the default stream. + * + * @return OPAL_SUCCESS or error status on failure + */ +typedef int (*opal_accelerator_base_get_default_stream_fn_t)( + int dev_id, opal_accelerator_stream_t **stream); + /** * Check whether a pointer belongs to an accelerator or not. 
* interfaces @@ -273,6 +291,29 @@ typedef int (*opal_accelerator_base_module_memmove_fn_t)( int dest_dev_id, int src_dev_id, void *dest, const void *src, size_t size, opal_accelerator_transfer_type_t type); +/** + * Copies memory asynchronously from src to dest. Memory of dest and src + * may overlap. Optionally can specify the transfer type to + * avoid pointer detection for performance. The operations will be enqueued + * into the provided stream but are not guaranteed to be complete upon return. + * + * @param[IN] dest_dev_id Associated device to copy to or + * MCA_ACCELERATOR_NO_DEVICE_ID + * @param[IN] src_dev_id Associated device to copy from or + * MCA_ACCELERATOR_NO_DEVICE_ID + * @param[IN] dest Destination to copy memory to + * @param[IN] src Source to copy memory from + * @param[IN] size Size of memory to copy + * @param[IN] stream Stream to perform asynchronous move on + * @param[IN] type Transfer type field for performance + * Can be set to MCA_ACCELERATOR_TRANSFER_UNSPEC + * if caller is unsure of the transfer direction. + * + * @return OPAL_SUCCESS or error status on failure + */ +typedef int (*opal_accelerator_base_module_memmove_async_fn_t)( + int dest_dev_id, int src_dev_id, void *dest, const void *src, size_t size, + opal_accelerator_stream_t *stream, opal_accelerator_transfer_type_t type); /** * Allocates size bytes memory from the device and sets ptr to the * pointer of the allocated memory. The memory is not initialized. @@ -303,6 +344,44 @@ typedef int (*opal_accelerator_base_module_mem_alloc_fn_t)( typedef int (*opal_accelerator_base_module_mem_release_fn_t)( int dev_id, void *ptr); +/** + * Allocates size bytes memory from the device and sets ptr to the + * pointer of the allocated memory. The memory is not initialized. + * The allocation request is placed into the stream object. + * Any use of the memory must succeed the completion of this + * operation on the stream. 
+ * + * @param[IN] dev_id Associated device for the allocation or + * MCA_ACCELERATOR_NO_DEVICE_ID + * @param[OUT] ptr Returns pointer to allocated memory + * @param[IN] size Size of memory to allocate + * @param[IN] stream Stream into which to insert the allocation request + * + * @return OPAL_SUCCESS or error status on failure + */ +typedef int (*opal_accelerator_base_module_mem_alloc_stream_fn_t)( + int dev_id, void **ptr, size_t size, opal_accelerator_stream_t *stream); + +/** + * Frees the memory space pointed to by ptr which has been returned by + * a previous call to an opal_accelerator_base_module_mem_alloc_stream_fn_t(). + * If the function is called on a ptr that has already been freed, + * undefined behavior occurs. If ptr is NULL, no operation is performed, + * and the function returns OPAL_SUCCESS. + * The release of the memory will be inserted into the stream and occurs after + * all previous operations have completed. + * + * @param[IN] dev_id Associated device for the allocation or + * MCA_ACCELERATOR_NO_DEVICE_ID + * @param[IN] ptr Pointer to free + * @param[IN] stream Stream into which to insert the free operation + * + * @return OPAL_SUCCESS or error status on failure + */ +typedef int (*opal_accelerator_base_module_mem_release_stream_fn_t)( + int dev_id, void *ptr, opal_accelerator_stream_t *stream); + + /** * Retrieves the base address and/or size of a memory allocation of the * device. @@ -394,10 +473,25 @@ typedef int (*opal_accelerator_base_module_device_can_access_peer_fn_t)( typedef int (*opal_accelerator_base_module_get_buffer_id_fn_t)( int dev_id, const void *addr, opal_accelerator_buffer_id_t *buf_id); +/** + * Wait for the completion of all operations inserted into the stream. + * + * @param[IN] stram The stream to wait for. 
+ * + * @return OPAL_SUCCESS or error status on failure + */ +typedef int (*opal_accelerator_base_module_wait_stream_fn_t)(opal_accelerator_stream_t *stream); + +typedef int (*opal_accelerator_base_module_get_num_devices_fn_t)(int *num_devices); + +typedef int (*opal_accelerator_base_module_get_mem_bw_fn_t)(int num_devices, float *bw); + /* * the standard public API data structure */ typedef struct { + /* default stream pointer */ + opal_accelerator_base_get_default_stream_fn_t get_default_stream; /* accelerator function table */ opal_accelerator_base_module_check_addr_fn_t check_addr; @@ -408,10 +502,13 @@ typedef struct { opal_accelerator_base_module_memcpy_async_fn_t mem_copy_async; opal_accelerator_base_module_memcpy_fn_t mem_copy; + opal_accelerator_base_module_memmove_async_fn_t mem_move_async; opal_accelerator_base_module_memmove_fn_t mem_move; opal_accelerator_base_module_mem_alloc_fn_t mem_alloc; opal_accelerator_base_module_mem_release_fn_t mem_release; + opal_accelerator_base_module_mem_alloc_stream_fn_t mem_alloc_stream; + opal_accelerator_base_module_mem_release_stream_fn_t mem_release_stream; opal_accelerator_base_module_get_address_range_fn_t get_address_range; opal_accelerator_base_module_host_register_fn_t host_register; @@ -422,6 +519,12 @@ typedef struct { opal_accelerator_base_module_device_can_access_peer_fn_t device_can_access_peer; opal_accelerator_base_module_get_buffer_id_fn_t get_buffer_id; + + opal_accelerator_base_module_wait_stream_fn_t wait_stream; + + opal_accelerator_base_module_get_num_devices_fn_t num_devices; + + opal_accelerator_base_module_get_mem_bw_fn_t get_mem_bw; } opal_accelerator_base_module_t; /** diff --git a/opal/mca/accelerator/base/accelerator_base_frame.c b/opal/mca/accelerator/base/accelerator_base_frame.c index fcaf86be94e..0721772bf1e 100644 --- a/opal/mca/accelerator/base/accelerator_base_frame.c +++ b/opal/mca/accelerator/base/accelerator_base_frame.c @@ -57,6 +57,12 @@ OBJ_CLASS_INSTANCE( NULL, NULL); 
+OBJ_CLASS_INSTANCE( + opal_accelerator_mempool_t, + opal_object_t, + NULL, + NULL); + MCA_BASE_FRAMEWORK_DECLARE(opal, accelerator, "OPAL Accelerator Framework", opal_accelerator_base_frame_register, opal_accelerator_base_frame_open, opal_accelerator_base_frame_close, mca_accelerator_base_static_components, diff --git a/opal/mca/accelerator/cuda/Makefile.am b/opal/mca/accelerator/cuda/Makefile.am index 5646890bab3..6f19b62cb63 100644 --- a/opal/mca/accelerator/cuda/Makefile.am +++ b/opal/mca/accelerator/cuda/Makefile.am @@ -34,11 +34,13 @@ mcacomponentdir = $(opallibdir) mcacomponent_LTLIBRARIES = $(component_install) mca_accelerator_cuda_la_SOURCES = $(sources) -mca_accelerator_cuda_la_LDFLAGS = -module -avoid-version +mca_accelerator_cuda_la_LDFLAGS = -module -avoid-version \ + $(accelerator_cuda_LDFLAGS) $(accelerator_cudart_LDFLAGS) mca_accelerator_cuda_la_LIBADD = $(top_builddir)/opal/lib@OPAL_LIB_NAME@.la \ - $(accelerator_cuda_LIBS) + $(accelerator_cuda_LIBS) $(accelerator_cudart_LIBS) noinst_LTLIBRARIES = $(component_noinst) libmca_accelerator_cuda_la_SOURCES =$(sources) -libmca_accelerator_cuda_la_LDFLAGS = -module -avoid-version -libmca_accelerator_cuda_la_LIBADD = $(accelerator_cuda_LIBS) +libmca_accelerator_cuda_la_LDFLAGS = -module -avoid-version \ + $(accelerator_cuda_LDFLAGS) $(accelerator_cudart_LDFLAGS) +libmca_accelerator_cuda_la_LIBADD = $(accelerator_cuda_LIBS) $(accelerator_cudart_LIBS) diff --git a/opal/mca/accelerator/cuda/accelerator_cuda.c b/opal/mca/accelerator/cuda/accelerator_cuda.c index 49d181a0b00..3ff45de0efe 100644 --- a/opal/mca/accelerator/cuda/accelerator_cuda.c +++ b/opal/mca/accelerator/cuda/accelerator_cuda.c @@ -16,6 +16,7 @@ #include "opal_config.h" #include +#include #include "accelerator_cuda.h" #include "opal/mca/accelerator/base/base.h" @@ -23,6 +24,7 @@ #include "opal/util/show_help.h" #include "opal/util/proc.h" /* Accelerator API's */ +static int accelerator_cuda_get_default_stream(int dev_id, 
opal_accelerator_stream_t **stream); static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *flags); static int accelerator_cuda_create_stream(int dev_id, opal_accelerator_stream_t **stream); @@ -34,10 +36,13 @@ static int accelerator_cuda_memcpy_async(int dest_dev_id, int src_dev_id, void * opal_accelerator_stream_t *stream, opal_accelerator_transfer_type_t type); static int accelerator_cuda_memcpy(int dest_dev_id, int src_dev_id, void *dest, const void *src, size_t size, opal_accelerator_transfer_type_t type); +static int accelerator_cuda_memmove_async(int dest_dev_id, int src_dev_id, void *dest, const void *src, size_t size, opal_accelerator_stream_t *stream, opal_accelerator_transfer_type_t type); static int accelerator_cuda_memmove(int dest_dev_id, int src_dev_id, void *dest, const void *src, size_t size, opal_accelerator_transfer_type_t type); static int accelerator_cuda_mem_alloc(int dev_id, void **ptr, size_t size); static int accelerator_cuda_mem_release(int dev_id, void *ptr); +static int accelerator_cuda_mem_alloc_stream(int dev_id, void **ptr, size_t size, opal_accelerator_stream_t *stream); +static int accelerator_cuda_mem_release_stream(int dev_id, void *ptr, opal_accelerator_stream_t *stream); static int accelerator_cuda_get_address_range(int dev_id, const void *ptr, void **base, size_t *size); @@ -50,8 +55,16 @@ static int accelerator_cuda_device_can_access_peer( int *access, int dev1, int d static int accelerator_cuda_get_buffer_id(int dev_id, const void *addr, opal_accelerator_buffer_id_t *buf_id); +static int accelerator_cuda_wait_stream(opal_accelerator_stream_t *stream); + +static int accelerator_cuda_get_num_devices(int *num_devices); + +static int accelerator_cuda_get_mem_bw(int device, float *bw); + opal_accelerator_base_module_t opal_accelerator_cuda_module = { + accelerator_cuda_get_default_stream, + accelerator_cuda_check_addr, accelerator_cuda_create_stream, @@ -62,9 +75,12 @@ opal_accelerator_base_module_t 
opal_accelerator_cuda_module = accelerator_cuda_memcpy_async, accelerator_cuda_memcpy, + accelerator_cuda_memmove_async, accelerator_cuda_memmove, accelerator_cuda_mem_alloc, accelerator_cuda_mem_release, + accelerator_cuda_mem_alloc_stream, + accelerator_cuda_mem_release_stream, accelerator_cuda_get_address_range, accelerator_cuda_host_register, @@ -74,9 +90,31 @@ opal_accelerator_base_module_t opal_accelerator_cuda_module = accelerator_cuda_get_device_pci_attr, accelerator_cuda_device_can_access_peer, - accelerator_cuda_get_buffer_id + accelerator_cuda_get_buffer_id, + + accelerator_cuda_wait_stream, + accelerator_cuda_get_num_devices, + accelerator_cuda_get_mem_bw }; +static int accelerator_cuda_get_device_id(CUcontext mem_ctx) { + /* query the device from the context */ + int dev_id = -1; + CUdevice ptr_dev; + cuCtxPushCurrent(mem_ctx); + cuCtxGetDevice(&ptr_dev); + for (int i = 0; i < opal_accelerator_cuda_num_devices; ++i) { + CUdevice dev; + cuDeviceGet(&dev, i); + if (dev == ptr_dev) { + dev_id = i; + break; + } + } + cuCtxPopCurrent(&mem_ctx); + return dev_id; +} + static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *flags) { CUresult result; @@ -125,6 +163,9 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t * } else if (0 == mem_type) { /* This can happen when CUDA is initialized but dbuf is not valid CUDA pointer */ return 0; + } else { + /* query the device from the context */ + *dev_id = accelerator_cuda_get_device_id(mem_ctx); } /* Must be a device pointer */ assert(CU_MEMORYTYPE_DEVICE == mem_type); @@ -140,6 +181,10 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t * } else if (CU_MEMORYTYPE_HOST == mem_type) { /* Host memory, nothing to do here */ return 0; + } else { + result = cuPointerGetAttribute(&mem_ctx, CU_POINTER_ATTRIBUTE_CONTEXT, dbuf); + /* query the device from the context */ + *dev_id = accelerator_cuda_get_device_id(mem_ctx); } /* Must be a 
device pointer */ assert(CU_MEMORYTYPE_DEVICE == mem_type); @@ -187,7 +232,7 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t * } } - /* WORKAROUND - They are times when the above code determines a pice of memory + /* WORKAROUND - There are times when the above code determines a pice of memory * is GPU memory, but it actually is not. That has been seen on multi-GPU systems * with 6 or 8 GPUs on them. Therefore, we will do this extra check. Note if we * made it this far, then the assumption at this point is we have GPU memory. @@ -211,6 +256,16 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t * return 1; } +static int accelerator_cuda_get_default_stream(int dev_id, opal_accelerator_stream_t **stream) +{ + int delayed_init = opal_accelerator_cuda_delayed_init(); + if (OPAL_UNLIKELY(0 != delayed_init)) { + return delayed_init; + } + *stream = &opal_accelerator_cuda_default_stream.base; + return OPAL_SUCCESS; +} + static int accelerator_cuda_create_stream(int dev_id, opal_accelerator_stream_t **stream) { CUresult result; @@ -391,6 +446,8 @@ static int accelerator_cuda_memcpy(int dest_dev_id, int src_dev_id, void *dest, return OPAL_ERR_BAD_PARAM; } +#if 0 + /* Async copy then synchronize is the default behavior as some applications * cannot utilize synchronous copies. In addition, host memory does not need * to be page-locked if an Async memory copy is done (It just makes it synchronous @@ -399,26 +456,28 @@ static int accelerator_cuda_memcpy(int dest_dev_id, int src_dev_id, void *dest, * Additionally, cuMemcpy is not necessarily always synchronous. 
See: * https://docs.nvidia.com/cuda/cuda-driver-api/api-sync-behavior.html * TODO: Add optimizations for type field */ - result = cuMemcpyAsync((CUdeviceptr) dest, (CUdeviceptr) src, size, opal_accelerator_cuda_memcpy_stream); + result = cuMemcpyAsync((CUdeviceptr) dest, (CUdeviceptr) src, size, *(CUstream*)opal_accelerator_cuda_memcpy_stream.base.stream); if (OPAL_UNLIKELY(CUDA_SUCCESS != result)) { opal_show_help("help-accelerator-cuda.txt", "cuMemcpyAsync failed", true, dest, src, size, result); return OPAL_ERROR; } - result = cuStreamSynchronize(opal_accelerator_cuda_memcpy_stream); + result = cuStreamSynchronize(*(CUstream*)opal_accelerator_cuda_memcpy_stream.base.stream); +#endif //0 + result = cuMemcpy((CUdeviceptr) dest, (CUdeviceptr) src, size); if (OPAL_UNLIKELY(CUDA_SUCCESS != result)) { - opal_show_help("help-accelerator-cuda.txt", "cuStreamSynchronize failed", true, + opal_show_help("help-accelerator-cuda.txt", "cuMemcpy failed", true, OPAL_PROC_MY_HOSTNAME, result); return OPAL_ERROR; } return OPAL_SUCCESS; } -static int accelerator_cuda_memmove(int dest_dev_id, int src_dev_id, void *dest, const void *src, size_t size, - opal_accelerator_transfer_type_t type) +static int accelerator_cuda_memmove_async(int dest_dev_id, int src_dev_id, void *dest, const void *src, size_t size, opal_accelerator_stream_t *stream, opal_accelerator_transfer_type_t type) { CUdeviceptr tmp; CUresult result; + void *ptr; int delayed_init = opal_accelerator_cuda_delayed_init(); if (OPAL_UNLIKELY(0 != delayed_init)) { @@ -429,29 +488,42 @@ static int accelerator_cuda_memmove(int dest_dev_id, int src_dev_id, void *dest, return OPAL_ERR_BAD_PARAM; } - result = cuMemAlloc(&tmp, size); - if (OPAL_UNLIKELY(CUDA_SUCCESS != result)) { + result = accelerator_cuda_mem_alloc_stream(src_dev_id, &ptr, size, stream); + if (OPAL_UNLIKELY(OPAL_SUCCESS != result)) { return OPAL_ERROR; } - result = cuMemcpyAsync(tmp, (CUdeviceptr) src, size, opal_accelerator_cuda_memcpy_stream); + tmp = 
(CUdeviceptr)ptr; + result = cuMemcpyAsync(tmp, (CUdeviceptr) src, size, *(CUstream*)stream->stream); if (OPAL_UNLIKELY(CUDA_SUCCESS != result)) { opal_show_help("help-accelerator-cuda.txt", "cuMemcpyAsync failed", true, tmp, src, size, result); return OPAL_ERROR; } - result = cuMemcpyAsync((CUdeviceptr) dest, tmp, size, opal_accelerator_cuda_memcpy_stream); + result = cuMemcpyAsync((CUdeviceptr) dest, tmp, size, *(CUstream*)stream->stream); if (OPAL_UNLIKELY(CUDA_SUCCESS != result)) { opal_show_help("help-accelerator-cuda.txt", "cuMemcpyAsync failed", true, dest, tmp, size, result); return OPAL_ERROR; } - result = cuStreamSynchronize(opal_accelerator_cuda_memcpy_stream); + return accelerator_cuda_mem_release_stream(src_dev_id, ptr, stream); +} + +static int accelerator_cuda_memmove(int dest_dev_id, int src_dev_id, void *dest, const void *src, size_t size, + opal_accelerator_transfer_type_t type) +{ + int ret; + CUresult result; + + ret = accelerator_cuda_memmove_async(dest_dev_id, src_dev_id, dest, src, size, &opal_accelerator_cuda_memcpy_stream.base, type); + if (OPAL_SUCCESS != ret) { + return OPAL_ERROR; + } + result = accelerator_cuda_wait_stream(&opal_accelerator_cuda_memcpy_stream.base); if (OPAL_UNLIKELY(CUDA_SUCCESS != result)) { opal_show_help("help-accelerator-cuda.txt", "cuStreamSynchronize failed", true, OPAL_PROC_MY_HOSTNAME, result); return OPAL_ERROR; } - cuMemFree(tmp); return OPAL_SUCCESS; } @@ -468,14 +540,47 @@ static int accelerator_cuda_mem_alloc(int dev_id, void **ptr, size_t size) return OPAL_ERR_BAD_PARAM; } - if (size > 0) { - result = cuMemAlloc((CUdeviceptr *) ptr, size); - if (OPAL_UNLIKELY(CUDA_SUCCESS != result)) { - opal_show_help("help-accelerator-cuda.txt", "cuMemAlloc failed", true, - OPAL_PROC_MY_HOSTNAME, result); - return OPAL_ERROR; +#if 0 + /* prefer managed memory */ + result = cudaMallocManaged(ptr, size, cudaMemAttachGlobal); + if (cudaSuccess == result) { + return OPAL_SUCCESS; + } +#endif // 0 + + /* fall-back to 
discrete memory */ + +#if CUDA_VERSION >= 11020 && 0 + /* Try to allocate the memory from a memory pool, if available */ + /* get the default pool */ + cudaMemPool_t mpool; + result = cudaDeviceGetDefaultMemPool(&mpool, dev_id); + if (cudaSuccess == result) { + result = cudaMallocFromPoolAsync(ptr, size, mpool, opal_accelerator_cuda_alloc_stream); + if (cudaSuccess == result) { + /* this is a blocking function, so wait for the allocation to happen */ + result = cuStreamSynchronize(opal_accelerator_cuda_alloc_stream); + if (cudaSuccess == result) { + printf("CUDA from mempool %p\n", *ptr); + return OPAL_SUCCESS; + } } } + if (cudaErrorNotSupported != result) { + opal_show_help("help-accelerator-cuda.txt", "cudaMallocFromPoolAsync failed", true, + OPAL_PROC_MY_HOSTNAME, result); + return OPAL_ERROR; + } + /* fall-back to regular allocation */ +#endif // CUDA_VERSION >= 11020 + + result = cuMemAlloc((CUdeviceptr *) ptr, size); + if (OPAL_UNLIKELY(CUDA_SUCCESS != result)) { + opal_show_help("help-accelerator-cuda.txt", "cuMemAlloc failed", true, + OPAL_PROC_MY_HOSTNAME, result); + return OPAL_ERROR; + } + //printf("CUDA from cuMemAlloc %p\n", *ptr); return 0; } @@ -673,3 +778,117 @@ static int accelerator_cuda_get_buffer_id(int dev_id, const void *addr, opal_acc } return OPAL_SUCCESS; } + + +static int accelerator_cuda_mem_alloc_stream( + int dev_id, + void **addr, + size_t size, + opal_accelerator_stream_t *stream) +{ + + int delayed_init = opal_accelerator_cuda_delayed_init(); + if (OPAL_UNLIKELY(0 != delayed_init)) { + return delayed_init; + } + +#if CUDA_VERSION >= 11020 + cudaError_t result; + + if (NULL == stream || NULL == addr || 0 == size) { + return OPAL_ERR_BAD_PARAM; + } + + /* Try to allocate the memory from a memory pool, if available */ + /* get the default pool */ + cudaMemPool_t mpool; + result = cudaDeviceGetDefaultMemPool(&mpool, dev_id); + if (cudaSuccess == result) { + result = cudaMallocFromPoolAsync(addr, size, mpool, 
*(cudaStream_t*)stream->stream); + if (cudaSuccess == result) { + return OPAL_SUCCESS; + } + } + if (cudaErrorNotSupported != result) { + opal_show_help("help-accelerator-cuda.txt", "cudaMallocFromPoolAsync failed", true, + OPAL_PROC_MY_HOSTNAME, result); + return OPAL_ERROR; + } + /* fall-back to regular stream allocation */ + + result = cudaMallocAsync(addr, size, *(cudaStream_t*)stream->stream); + if (OPAL_UNLIKELY(cudaSuccess != result)) { + opal_show_help("help-accelerator-cuda.txt", "cuMemAlloc failed", true, + OPAL_PROC_MY_HOSTNAME, result); + return OPAL_ERROR; + } + return OPAL_SUCCESS; +#else + return accelerator_cuda_mem_alloc(dev_id, addr, size); +#endif // CUDA_VERSION >= 11020 +} + + +static int accelerator_cuda_mem_release_stream( + int dev_id, + void *addr, + opal_accelerator_stream_t *stream) +{ +#if CUDA_VERSION >= 11020 + cudaError_t result; + + if (NULL == stream || NULL == addr) { + return OPAL_ERR_BAD_PARAM; + } + + result = cudaFreeAsync(addr, *(cudaStream_t*)stream->stream); + if (OPAL_UNLIKELY(cudaSuccess != result)) { + opal_show_help("help-accelerator-cuda.txt", "cuMemAlloc failed", true, + OPAL_PROC_MY_HOSTNAME, result); + return OPAL_ERROR; + } + return OPAL_SUCCESS; +#else + /* wait for everything on the device to complete */ + accelerator_cuda_wait_stream(stream); + return accelerator_cuda_mem_release(dev_id, addr); +#endif // CUDA_VERSION >= 11020 +} + + +static int accelerator_cuda_wait_stream(opal_accelerator_stream_t *stream) +{ + CUresult result; + result = cuStreamSynchronize(*(CUstream*)stream->stream); + if (OPAL_UNLIKELY(CUDA_SUCCESS != result)) { + opal_show_help("help-accelerator-cuda.txt", "cuStreamSynchronize failed", true, + OPAL_PROC_MY_HOSTNAME, result); + return OPAL_ERROR; + } + return OPAL_SUCCESS; +} + + +static int accelerator_cuda_get_num_devices(int *num_devices) +{ + + int delayed_init = opal_accelerator_cuda_delayed_init(); + if (OPAL_UNLIKELY(0 != delayed_init)) { + return delayed_init; + } + + *num_devices = 
opal_accelerator_cuda_num_devices; + return OPAL_SUCCESS; +} + +static int accelerator_cuda_get_mem_bw(int device, float *bw) +{ + int delayed_init = opal_accelerator_cuda_delayed_init(); + if (OPAL_UNLIKELY(0 != delayed_init)) { + return delayed_init; + } + assert(opal_accelerator_cuda_mem_bw != NULL); + + *bw = opal_accelerator_cuda_mem_bw[device]; + return OPAL_SUCCESS; +} diff --git a/opal/mca/accelerator/cuda/accelerator_cuda.h b/opal/mca/accelerator/cuda/accelerator_cuda.h index 694a4192231..3ef66820d72 100644 --- a/opal/mca/accelerator/cuda/accelerator_cuda.h +++ b/opal/mca/accelerator/cuda/accelerator_cuda.h @@ -15,6 +15,7 @@ #include "opal_config.h" #include +#include #include "opal/mca/accelerator/accelerator.h" #include "opal/mca/threads/mutex.h" @@ -38,13 +39,19 @@ typedef struct opal_accelerator_cuda_event_t opal_accelerator_cuda_event_t; OBJ_CLASS_DECLARATION(opal_accelerator_cuda_event_t); /* Declare extern variables, defined in accelerator_cuda_component.c */ -OPAL_DECLSPEC extern CUstream opal_accelerator_cuda_memcpy_stream; +OPAL_DECLSPEC extern opal_accelerator_cuda_stream_t opal_accelerator_cuda_memcpy_stream; +OPAL_DECLSPEC extern CUstream opal_accelerator_cuda_alloc_stream; +OPAL_DECLSPEC extern opal_accelerator_cuda_stream_t opal_accelerator_cuda_default_stream; OPAL_DECLSPEC extern opal_mutex_t opal_accelerator_cuda_stream_lock; OPAL_DECLSPEC extern opal_accelerator_cuda_component_t mca_accelerator_cuda_component; OPAL_DECLSPEC extern opal_accelerator_base_module_t opal_accelerator_cuda_module; +OPAL_DECLSPEC extern int opal_accelerator_cuda_num_devices; + +OPAL_DECLSPEC extern float *opal_accelerator_cuda_mem_bw; + OPAL_DECLSPEC extern int opal_accelerator_cuda_delayed_init(void); END_C_DECLS diff --git a/opal/mca/accelerator/cuda/accelerator_cuda_component.c b/opal/mca/accelerator/cuda/accelerator_cuda_component.c index d880ee5dca8..04006f85dd4 100644 --- a/opal/mca/accelerator/cuda/accelerator_cuda_component.c +++ 
b/opal/mca/accelerator/cuda/accelerator_cuda_component.c @@ -34,13 +34,19 @@ #include "opal/sys/atomic.h" /* Define global variables, used in accelerator_cuda.c */ -CUstream opal_accelerator_cuda_memcpy_stream = NULL; +opal_accelerator_cuda_stream_t opal_accelerator_cuda_memcpy_stream = {0}; +CUstream opal_accelerator_cuda_alloc_stream = NULL; +opal_accelerator_cuda_stream_t opal_accelerator_cuda_default_stream = {0}; opal_mutex_t opal_accelerator_cuda_stream_lock = {0}; +int opal_accelerator_cuda_num_devices = 0; /* Initialization lock for delayed cuda initialization */ static opal_mutex_t accelerator_cuda_init_lock; static bool accelerator_cuda_init_complete = false; +float *opal_accelerator_cuda_mem_bw = NULL; + + #define STRINGIFY2(x) #x #define STRINGIFY(x) STRINGIFY2(x) @@ -121,7 +127,8 @@ static int accelerator_cuda_component_register(void) int opal_accelerator_cuda_delayed_init() { - int result = OPAL_SUCCESS; + CUresult result = OPAL_SUCCESS; + int prio_lo, prio_hi; CUcontext cuContext; /* Double checked locking to avoid having to @@ -137,6 +144,8 @@ int opal_accelerator_cuda_delayed_init() goto out; } + cuDeviceGetCount(&opal_accelerator_cuda_num_devices); + /* Check to see if this process is running in a CUDA context. If * so, all is good. If not, then disable registration of memory. 
*/ result = cuCtxGetCurrent(&cuContext); @@ -145,31 +154,111 @@ int opal_accelerator_cuda_delayed_init() goto out; } else if ((CUDA_SUCCESS == result) && (NULL == cuContext)) { opal_output_verbose(20, opal_accelerator_base_framework.framework_output, "CUDA: cuCtxGetCurrent returned NULL context"); - result = OPAL_ERROR; - goto out; + /* create a context for each device */ + for (int i = 0; i < opal_accelerator_cuda_num_devices; ++i) { + CUdevice dev; + result = cuDeviceGet(&dev, i); + if (CUDA_SUCCESS != result) { + opal_output_verbose(20, opal_accelerator_base_framework.framework_output, + "CUDA: cuDeviceGet failed"); + goto out; + } + result = cuDevicePrimaryCtxRetain(&cuContext, dev); + if (CUDA_SUCCESS != result) { + opal_output_verbose(20, opal_accelerator_base_framework.framework_output, + "CUDA: cuDevicePrimaryCtxRetain failed"); + goto out; + } + if (0 == i) { + result = cuCtxPushCurrent(cuContext); + if (CUDA_SUCCESS != result) { + opal_output_verbose(20, opal_accelerator_base_framework.framework_output, + "CUDA: cuCtxPushCurrent failed"); + goto out; + } + } + } + } else { opal_output_verbose(20, opal_accelerator_base_framework.framework_output, "CUDA: cuCtxGetCurrent succeeded"); } + + /* Create stream for use in cuMemcpyAsync synchronous copies */ + CUstream memcpy_stream; + result = cuStreamCreate(&memcpy_stream, 0); + if (OPAL_UNLIKELY(result != CUDA_SUCCESS)) { + opal_show_help("help-accelerator-cuda.txt", "cuStreamCreate failed", true, + OPAL_PROC_MY_HOSTNAME, result); + goto out; + } + OBJ_CONSTRUCT(&opal_accelerator_cuda_memcpy_stream, opal_accelerator_cuda_stream_t); + opal_accelerator_cuda_memcpy_stream.base.stream = malloc(sizeof(CUstream)); + *(CUstream*)opal_accelerator_cuda_memcpy_stream.base.stream = memcpy_stream; + /* Create stream for use in cuMemcpyAsync synchronous copies */ - result = cuStreamCreate(&opal_accelerator_cuda_memcpy_stream, 0); + result = cuStreamCreate(&opal_accelerator_cuda_alloc_stream, 0); if (OPAL_UNLIKELY(result != 
CUDA_SUCCESS)) { opal_show_help("help-accelerator-cuda.txt", "cuStreamCreate failed", true, OPAL_PROC_MY_HOSTNAME, result); goto out; } + /* Create a default stream to be used by various components. + * We try to create a high-priority stream and fall back to a regular stream. + */ + CUstream *default_stream = malloc(sizeof(CUstream)); + result = cuCtxGetStreamPriorityRange(&prio_lo, &prio_hi); + if (CUDA_SUCCESS != result) { + result = cuStreamCreateWithPriority(default_stream, + CU_STREAM_NON_BLOCKING, prio_hi); + } else { + result = cuStreamCreate(default_stream, 0); + } + if (OPAL_UNLIKELY(result != CUDA_SUCCESS)) { + opal_show_help("help-accelerator-cuda.txt", "cuStreamCreate failed", true, + OPAL_PROC_MY_HOSTNAME, result); + goto out; + } + OBJ_CONSTRUCT(&opal_accelerator_cuda_default_stream, opal_accelerator_cuda_stream_t); + opal_accelerator_cuda_default_stream.base.stream = default_stream; + result = cuMemHostRegister(&checkmem, sizeof(int), 0); if (result != CUDA_SUCCESS) { /* If registering the memory fails, print a message and continue. * This is not a fatal error. 
*/ opal_show_help("help-accelerator-cuda.txt", "cuMemHostRegister during init failed", true, &checkmem, sizeof(int), OPAL_PROC_MY_HOSTNAME, result, "checkmem"); - } else { opal_output_verbose(20, opal_accelerator_base_framework.framework_output, "CUDA: cuMemHostRegister OK on test region"); } + + opal_accelerator_cuda_mem_bw = malloc(sizeof(float)*opal_accelerator_cuda_num_devices); + for (int i = 0; i < opal_accelerator_cuda_num_devices; ++i) { + CUdevice dev; + result = cuDeviceGet(&dev, i); + if (CUDA_SUCCESS != result) { + opal_output_verbose(20, opal_accelerator_base_framework.framework_output, + "CUDA: cuDeviceGet failed"); + goto out; + } + int mem_clock_rate; // kHz + result = cuDeviceGetAttribute(&mem_clock_rate, + CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, + dev); + int bus_width; // bit + result = cuDeviceGetAttribute(&bus_width, + CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, + dev); + /* bw = clock_rate * bus width * 2bit multiplier + * See https://forums.developer.nvidia.com/t/memory-clock-rate/107940 + */ + float bw = ((float)mem_clock_rate*(float)bus_width*2.0) / 1024 / 1024 / 8; + //printf("clock rate: %d kHz, bus width: %d bit, bandwidth: %f GB/s\n", mem_clock_rate, bus_width, bw); + opal_accelerator_cuda_mem_bw[i] = bw; + } + result = OPAL_SUCCESS; opal_atomic_wmb(); accelerator_cuda_init_complete = true; @@ -182,6 +271,9 @@ static opal_accelerator_base_module_t* accelerator_cuda_init(void) { OBJ_CONSTRUCT(&opal_accelerator_cuda_stream_lock, opal_mutex_t); OBJ_CONSTRUCT(&accelerator_cuda_init_lock, opal_mutex_t); + OBJ_CONSTRUCT(&opal_accelerator_cuda_default_stream, opal_accelerator_stream_t); + OBJ_CONSTRUCT(&opal_accelerator_cuda_memcpy_stream, opal_accelerator_stream_t); + /* First check if the support is enabled. In the case that the user has * turned it off, we do not need to continue with any CUDA specific * initialization. Do this after MCA parameter registration. 
*/ @@ -213,9 +305,19 @@ static void accelerator_cuda_finalize(opal_accelerator_base_module_t* module) if (CUDA_SUCCESS != result) { ctx_ok = 0; } - if ((NULL != opal_accelerator_cuda_memcpy_stream) && ctx_ok) { - cuStreamDestroy(opal_accelerator_cuda_memcpy_stream); + if ((NULL != opal_accelerator_cuda_memcpy_stream.base.stream) && ctx_ok) { + OBJ_DESTRUCT(&opal_accelerator_cuda_memcpy_stream); + } + if ((NULL != opal_accelerator_cuda_alloc_stream) && ctx_ok) { + cuStreamDestroy(opal_accelerator_cuda_alloc_stream); } + if ((NULL != opal_accelerator_cuda_default_stream.base.stream) && ctx_ok) { + OBJ_DESTRUCT(&opal_accelerator_cuda_default_stream); + } + + free(opal_accelerator_cuda_mem_bw); + opal_accelerator_cuda_mem_bw = NULL; + OBJ_DESTRUCT(&opal_accelerator_cuda_stream_lock); OBJ_DESTRUCT(&accelerator_cuda_init_lock); diff --git a/opal/mca/accelerator/cuda/configure.m4 b/opal/mca/accelerator/cuda/configure.m4 index aa67623c8b2..2792d52c840 100644 --- a/opal/mca/accelerator/cuda/configure.m4 +++ b/opal/mca/accelerator/cuda/configure.m4 @@ -24,6 +24,7 @@ AC_DEFUN([MCA_opal_accelerator_cuda_CONFIG],[ AC_CONFIG_FILES([opal/mca/accelerator/cuda/Makefile]) OPAL_CHECK_CUDA([accelerator_cuda]) + OPAL_CHECK_CUDART([accelerator_cudart]) AS_IF([test "x$CUDA_SUPPORT" = "x1"], [$1], @@ -33,4 +34,8 @@ AC_DEFUN([MCA_opal_accelerator_cuda_CONFIG],[ AC_SUBST([accelerator_cuda_LDFLAGS]) AC_SUBST([accelerator_cuda_LIBS]) + AC_SUBST([accelerator_cudart_CPPFLAGS]) + AC_SUBST([accelerator_cudart_LDFLAGS]) + AC_SUBST([accelerator_cudart_LIBS]) + ])dnl diff --git a/opal/mca/accelerator/cuda/help-accelerator-cuda.txt b/opal/mca/accelerator/cuda/help-accelerator-cuda.txt index 2cf7a14bf5d..750344e3a78 100644 --- a/opal/mca/accelerator/cuda/help-accelerator-cuda.txt +++ b/opal/mca/accelerator/cuda/help-accelerator-cuda.txt @@ -262,3 +262,10 @@ Check the cuda.h file for what the return value means. A call to allocate memory within the CUDA support failed. 
This is an unrecoverable error and will cause the program to abort. Hostname: %s +# +[cudaMemPoolCreate failed] +The call to cudaMemPoolCreate failed. This is highly unusual and should +not happen. Please report this error to the Open MPI developers. + Hostname: %s + cudaMemPoolCreate return value: %d +Check the cuda_runtime_api.h file for what the return value means. \ No newline at end of file diff --git a/opal/mca/accelerator/null/accelerator_null_component.c b/opal/mca/accelerator/null/accelerator_null_component.c index 4a0d307497b..24f9e04419e 100644 --- a/opal/mca/accelerator/null/accelerator_null_component.c +++ b/opal/mca/accelerator/null/accelerator_null_component.c @@ -27,6 +27,8 @@ const char *opal_accelerator_null_component_version_string = "OPAL null accelerator MCA component version " OPAL_VERSION; +static opal_accelerator_stream_t default_stream; + /* * Component API functions */ @@ -37,8 +39,10 @@ static opal_accelerator_base_module_t* accelerator_null_init(void); static void accelerator_null_finalize(opal_accelerator_base_module_t* module); /* Accelerator API's */ +static int accelerator_null_get_default_stream(int dev_id, opal_accelerator_stream_t **stream); static int accelerator_null_check_addr(const void *addr, int *dev_id, uint64_t *flags); +static int accelerator_null_get_default_stream(int dev_id, opal_accelerator_stream_t **stream); static int accelerator_null_create_stream(int dev_id, opal_accelerator_stream_t **stream); static int accelerator_null_create_event(int dev_id, opal_accelerator_event_t **event); static int accelerator_null_record_event(int dev_id, opal_accelerator_event_t *event, opal_accelerator_stream_t *stream); @@ -48,11 +52,15 @@ static int accelerator_null_memcpy_async(int dest_dev_id, int src_dev_id, void * opal_accelerator_stream_t *stream, opal_accelerator_transfer_type_t type); static int accelerator_null_memcpy(int dest_dev_id, int src_dev_id, void *dest, const void *src, size_t size, opal_accelerator_transfer_type_t 
type); +static int accelerator_null_memmove_async(int dest_dev_id, int src_dev_id, void *dest, const void *src, size_t size, + opal_accelerator_stream_t *stream, opal_accelerator_transfer_type_t type); static int accelerator_null_memmove(int dest_dev_id, int src_dev_id, void *dest, const void *src, size_t size, opal_accelerator_transfer_type_t type); static int accelerator_null_mem_alloc(int dev_id, void **ptr, size_t size); static int accelerator_null_mem_release(int dev_id, void *ptr); +static int accelerator_null_mem_alloc_stream(int dev_id, void **ptr, size_t size, opal_accelerator_stream_t* stream); +static int accelerator_null_mem_release_stream(int dev_id, void *ptr, opal_accelerator_stream_t *stream); static int accelerator_null_get_address_range(int dev_id, const void *ptr, void **base, size_t *size); static int accelerator_null_host_register(int dev_id, void *ptr, size_t size); @@ -64,6 +72,12 @@ static int accelerator_null_device_can_access_peer(int *access, int dev1, int de static int accelerator_null_get_buffer_id(int dev_id, const void *addr, opal_accelerator_buffer_id_t *buf_id); +static int accelerator_null_wait_stream(opal_accelerator_stream_t *stream); + +static int accelerator_null_get_num_devices(int *num_devices); + +static int accelerator_null_get_mem_bw(int device, float *bw); + /* * Instantiate the public struct with all of our public information * and pointers to our public functions in it @@ -104,6 +118,8 @@ opal_accelerator_null_component_t mca_accelerator_null_component = {{ opal_accelerator_base_module_t opal_accelerator_null_module = { + accelerator_null_get_default_stream, + accelerator_null_check_addr, accelerator_null_create_stream, @@ -114,9 +130,12 @@ opal_accelerator_base_module_t opal_accelerator_null_module = accelerator_null_memcpy_async, accelerator_null_memcpy, + accelerator_null_memmove_async, accelerator_null_memmove, accelerator_null_mem_alloc, accelerator_null_mem_release, + accelerator_null_mem_alloc_stream, + 
accelerator_null_mem_release_stream, accelerator_null_get_address_range, accelerator_null_host_register, @@ -126,7 +145,11 @@ opal_accelerator_base_module_t opal_accelerator_null_module = accelerator_null_get_device_pci_attr, accelerator_null_device_can_access_peer, - accelerator_null_get_buffer_id + accelerator_null_get_buffer_id, + + accelerator_null_wait_stream, + accelerator_null_get_num_devices, + accelerator_null_get_mem_bw }; static int accelerator_null_open(void) @@ -161,6 +184,11 @@ static int accelerator_null_check_addr(const void *addr, int *dev_id, uint64_t * return 0; } +static int accelerator_null_get_default_stream(int dev_id, opal_accelerator_stream_t **stream) +{ + *stream = &default_stream; + return OPAL_SUCCESS; +} static int accelerator_null_create_stream(int dev_id, opal_accelerator_stream_t **stream) { *stream = OBJ_NEW(opal_accelerator_stream_t); @@ -204,6 +232,12 @@ static int accelerator_null_memmove(int dest_dev_id, int src_dev_id, void *dest, return OPAL_SUCCESS; } +static int accelerator_null_memmove_async(int dest_dev_id, int src_dev_id, void *dest, const void *src, size_t size, + opal_accelerator_stream_t *stream, opal_accelerator_transfer_type_t type) +{ + memmove(dest, src, size); + return OPAL_SUCCESS; +} static int accelerator_null_mem_alloc(int dev_id, void **ptr, size_t size) { *ptr = malloc(size); @@ -216,6 +250,23 @@ static int accelerator_null_mem_release(int dev_id, void *ptr) return OPAL_SUCCESS; } +static int accelerator_null_mem_alloc_stream(int dev_id, void **ptr, size_t size, + opal_accelerator_stream_t *stream) +{ + (void)stream; + *ptr = malloc(size); + return OPAL_SUCCESS; +} + +static int accelerator_null_mem_release_stream(int dev_id, void *ptr, + opal_accelerator_stream_t *stream) +{ + (void)stream; + free(ptr); + return OPAL_SUCCESS; +} + + static int accelerator_null_get_address_range(int dev_id, const void *ptr, void **base, size_t *size) { @@ -251,3 +302,22 @@ static int accelerator_null_get_buffer_id(int 
dev_id, const void *addr, opal_acc { + return OPAL_ERR_NOT_IMPLEMENTED; +} + + +static int accelerator_null_wait_stream(opal_accelerator_stream_t *stream) +{ + return OPAL_SUCCESS; +} + +static int accelerator_null_get_num_devices(int *num_devices) +{ + *num_devices = 0; + return OPAL_SUCCESS; +} + + +static int accelerator_null_get_mem_bw(int device, float *bw) +{ + *bw = 1.0; // return something that is not 0 + return OPAL_SUCCESS; +} \ No newline at end of file diff --git a/opal/mca/accelerator/rocm/accelerator_rocm.h b/opal/mca/accelerator/rocm/accelerator_rocm.h index fdc062af612..abd2f8125c3 100644 --- a/opal/mca/accelerator/rocm/accelerator_rocm.h +++ b/opal/mca/accelerator/rocm/accelerator_rocm.h @@ -55,7 +55,7 @@ struct opal_accelerator_rocm_event_t { typedef struct opal_accelerator_rocm_event_t opal_accelerator_rocm_event_t; OBJ_CLASS_DECLARATION(opal_accelerator_rocm_event_t); -OPAL_DECLSPEC extern hipStream_t opal_accelerator_rocm_MemcpyStream; +OPAL_DECLSPEC extern hipStream_t *opal_accelerator_rocm_MemcpyStream; OPAL_DECLSPEC extern int opal_accelerator_rocm_memcpy_async; OPAL_DECLSPEC extern int opal_accelerator_rocm_verbose; OPAL_DECLSPEC extern size_t opal_accelerator_rocm_memcpyH2D_limit; @@ -63,4 +63,9 @@ OPAL_DECLSPEC extern size_t opal_accelerator_rocm_memcpyD2H_limit; OPAL_DECLSPEC extern int opal_accelerator_rocm_lazy_init(void); +OPAL_DECLSPEC extern hipStream_t opal_accelerator_rocm_alloc_stream; +OPAL_DECLSPEC extern opal_accelerator_rocm_stream_t opal_accelerator_rocm_default_stream; +OPAL_DECLSPEC extern opal_mutex_t opal_accelerator_rocm_stream_lock; +OPAL_DECLSPEC extern int opal_accelerator_rocm_num_devices; +OPAL_DECLSPEC extern float *opal_accelerator_rocm_mem_bw; #endif diff --git a/opal/mca/accelerator/rocm/accelerator_rocm_component.c b/opal/mca/accelerator/rocm/accelerator_rocm_component.c index 317de021565..86d71c0034b 100644 --- a/opal/mca/accelerator/rocm/accelerator_rocm_component.c +++ 
b/opal/mca/accelerator/rocm/accelerator_rocm_component.c @@ -20,8 +20,10 @@ #include #include "opal/mca/dl/base/base.h" +#include "opal/mca/accelerator/base/base.h" #include "opal/runtime/opal_params.h" #include "accelerator_rocm.h" +#include "opal/util/proc.h" int opal_accelerator_rocm_memcpy_async = 1; int opal_accelerator_rocm_verbose = 0; @@ -32,7 +34,7 @@ size_t opal_accelerator_rocm_memcpyH2D_limit=1048576; static opal_mutex_t accelerator_rocm_init_lock; static bool accelerator_rocm_init_complete = false; -hipStream_t opal_accelerator_rocm_MemcpyStream = NULL; +hipStream_t *opal_accelerator_rocm_MemcpyStream = NULL; /* * Public string showing the accelerator rocm component version number @@ -40,6 +42,14 @@ hipStream_t opal_accelerator_rocm_MemcpyStream = NULL; const char *opal_accelerator_rocm_component_version_string = "OPAL rocm accelerator MCA component version " OPAL_VERSION; +/* Define global variables, used in accelerator_rocm.c */ +//opal_accelerator_rocm_stream_t opal_accelerator_rocm_memcpy_stream = {0}; +hipStream_t opal_accelerator_rocm_alloc_stream = NULL; +opal_accelerator_rocm_stream_t opal_accelerator_rocm_default_stream = {0}; +opal_mutex_t opal_accelerator_rocm_stream_lock = {0}; +int opal_accelerator_rocm_num_devices = 0; + +float *opal_accelerator_rocm_mem_bw = NULL; #define HIP_CHECK(condition) \ { \ @@ -160,6 +170,7 @@ static int accelerator_rocm_component_register(void) int opal_accelerator_rocm_lazy_init() { + int prio_hi, prio_lo; int err = OPAL_SUCCESS; /* Double checked locking to avoid having to @@ -175,12 +186,62 @@ int opal_accelerator_rocm_lazy_init() goto out; } - err = hipStreamCreate(&opal_accelerator_rocm_MemcpyStream); + hipGetDeviceCount(&opal_accelerator_rocm_num_devices); + + /* Create stream for use in cuMemcpyAsync synchronous copies */ + hipStream_t memcpy_stream; + err = hipStreamCreate(&memcpy_stream); + if (OPAL_UNLIKELY(err != hipSuccess)) { + opal_show_help("help-accelerator-rocm.txt", "hipStreamCreateWithFlags 
failed", true, + OPAL_PROC_MY_HOSTNAME, err); + goto out; + } + opal_accelerator_rocm_MemcpyStream = malloc(sizeof(hipStream_t)); + *(hipStream_t*)opal_accelerator_rocm_MemcpyStream = memcpy_stream; + + /* Create stream for use in cuMemcpyAsync synchronous copies */ + err = hipStreamCreateWithFlags(&opal_accelerator_rocm_alloc_stream, 0); + if (OPAL_UNLIKELY(err != hipSuccess)) { + opal_show_help("help-accelerator-rocm.txt", "hipStreamCreateWithFlags failed", true, + OPAL_PROC_MY_HOSTNAME, err); + goto out; + } + + /* Create a default stream to be used by various components. + * We try to create a high-priority stream and fall back to a regular stream. + */ + hipStream_t *default_stream = malloc(sizeof(hipStream_t)); + err = hipDeviceGetStreamPriorityRange(&prio_lo, &prio_hi); if (hipSuccess != err) { - opal_output(0, "Could not create hipStream, err=%d %s\n", - err, hipGetErrorString(err)); + err = hipStreamCreateWithFlags(default_stream, 0); + } else { + err = hipStreamCreateWithPriority(default_stream, + hipStreamNonBlocking, prio_hi); + } + if (OPAL_UNLIKELY(err != hipSuccess)) { + opal_show_help("help-accelerator-rocm.txt", "hipStreamCreateWithFlags failed", true, + OPAL_PROC_MY_HOSTNAME, err); goto out; } + OBJ_CONSTRUCT(&opal_accelerator_rocm_default_stream, opal_accelerator_rocm_stream_t); + opal_accelerator_rocm_default_stream.base.stream = default_stream; + + opal_accelerator_rocm_mem_bw = malloc(sizeof(float)*opal_accelerator_rocm_num_devices); + for (int i = 0; i < opal_accelerator_rocm_num_devices; ++i) { + int mem_clock_rate; // kHz + err = hipDeviceGetAttribute(&mem_clock_rate, + hipDeviceAttributeMemoryClockRate, + i); + int bus_width; // bit + err = hipDeviceGetAttribute(&bus_width, + hipDeviceAttributeMemoryBusWidth, + i); + /* bw = clock_rate * bus width * 2bit multiplier + * See https://forums.developer.nvidia.com/t/memory-clock-rate/107940 + */ + float bw = ((float)mem_clock_rate*(float)bus_width*2.0) / 1024 / 1024 / 8; + 
opal_accelerator_rocm_mem_bw[i] = bw; + } err = OPAL_SUCCESS; opal_atomic_wmb(); @@ -193,7 +254,7 @@ int opal_accelerator_rocm_lazy_init() static opal_accelerator_base_module_t* accelerator_rocm_init(void) { OBJ_CONSTRUCT(&accelerator_rocm_init_lock, opal_mutex_t); - + hipError_t err; if (opal_rocm_runtime_initialized) { @@ -215,14 +276,19 @@ static opal_accelerator_base_module_t* accelerator_rocm_init(void) static void accelerator_rocm_finalize(opal_accelerator_base_module_t* module) { - if (NULL != (void*)opal_accelerator_rocm_MemcpyStream) { - hipError_t err = hipStreamDestroy(opal_accelerator_rocm_MemcpyStream); + if (NULL != opal_accelerator_rocm_MemcpyStream) { + hipError_t err = hipStreamDestroy(*opal_accelerator_rocm_MemcpyStream); if (hipSuccess != err) { opal_output_verbose(10, 0, "hip_dl_finalize: error while destroying the hipStream\n"); } + free(opal_accelerator_rocm_MemcpyStream); opal_accelerator_rocm_MemcpyStream = NULL; } + free(opal_accelerator_rocm_mem_bw); + opal_accelerator_rocm_mem_bw = NULL; + + OBJ_DESTRUCT(&accelerator_rocm_init_lock); return; } diff --git a/opal/mca/accelerator/rocm/accelerator_rocm_module.c b/opal/mca/accelerator/rocm/accelerator_rocm_module.c index d5640db2100..ee31233199f 100644 --- a/opal/mca/accelerator/rocm/accelerator_rocm_module.c +++ b/opal/mca/accelerator/rocm/accelerator_rocm_module.c @@ -13,8 +13,10 @@ #include "opal/mca/accelerator/base/base.h" #include "opal/constants.h" #include "opal/util/output.h" +#include "opal/util/proc.h" /* Accelerator API's */ +static int mca_accelerator_rocm_get_default_stream(int dev_id, opal_accelerator_stream_t **stream); static int mca_accelerator_rocm_check_addr(const void *addr, int *dev_id, uint64_t *flags); static int mca_accelerator_rocm_create_stream(int dev_id, opal_accelerator_stream_t **stream); @@ -26,10 +28,13 @@ static int mca_accelerator_rocm_memcpy_async(int dest_dev_id, int src_dev_id, vo opal_accelerator_stream_t *stream, opal_accelerator_transfer_type_t type); 
static int mca_accelerator_rocm_memcpy(int dest_dev_id, int src_dev_id, void *dest, const void *src, size_t size, opal_accelerator_transfer_type_t type); +static int mca_accelerator_rocm_memmove_async(int dest_dev_id, int src_dev_id, void *dest, const void *src, size_t size, opal_accelerator_stream_t *stream, opal_accelerator_transfer_type_t type); static int mca_accelerator_rocm_memmove(int dest_dev_id, int src_dev_id, void *dest, const void *src, size_t size, opal_accelerator_transfer_type_t type); static int mca_accelerator_rocm_mem_alloc(int dev_id, void **ptr, size_t size); static int mca_accelerator_rocm_mem_release(int dev_id, void *ptr); +static int mca_accelerator_rocm_mem_alloc_stream(int dev_id, void **ptr, size_t size, opal_accelerator_stream_t *stream); +static int mca_accelerator_rocm_mem_release_stream(int dev_id, void *ptr, opal_accelerator_stream_t *stream); static int mca_accelerator_rocm_get_address_range(int dev_id, const void *ptr, void **base, size_t *size); @@ -42,8 +47,16 @@ static int mca_accelerator_rocm_device_can_access_peer( int *access, int dev1, i static int mca_accelerator_rocm_get_buffer_id(int dev_id, const void *addr, opal_accelerator_buffer_id_t *buf_id); +static int mca_accelerator_rocm_wait_stream(opal_accelerator_stream_t *stream); + +static int mca_accelerator_rocm_get_num_devices(int *num_devices); + +static int mca_accelerator_rocm_get_mem_bw(int device, float *bw); + opal_accelerator_base_module_t opal_accelerator_rocm_module = { + mca_accelerator_rocm_get_default_stream, //DONE + mca_accelerator_rocm_check_addr, mca_accelerator_rocm_create_stream, @@ -54,9 +67,12 @@ opal_accelerator_base_module_t opal_accelerator_rocm_module = mca_accelerator_rocm_memcpy_async, mca_accelerator_rocm_memcpy, + mca_accelerator_rocm_memmove_async, //DONE mca_accelerator_rocm_memmove, mca_accelerator_rocm_mem_alloc, mca_accelerator_rocm_mem_release, + mca_accelerator_rocm_mem_alloc_stream, //DONE + mca_accelerator_rocm_mem_release_stream, 
//DONE mca_accelerator_rocm_get_address_range, mca_accelerator_rocm_host_register, @@ -66,7 +82,12 @@ opal_accelerator_base_module_t opal_accelerator_rocm_module = mca_accelerator_rocm_get_device_pci_attr, mca_accelerator_rocm_device_can_access_peer, - mca_accelerator_rocm_get_buffer_id + mca_accelerator_rocm_get_buffer_id, + + mca_accelerator_rocm_wait_stream, //DONE + mca_accelerator_rocm_get_num_devices, //DONE + + mca_accelerator_rocm_get_mem_bw }; @@ -95,6 +116,9 @@ static int mca_accelerator_rocm_check_addr (const void *addr, int *dev_id, uint6 //*flags |= MCA_ACCELERATOR_FLAGS_HOST_ATOMICS; /* First access on a device pointer triggers ROCM support lazy initialization. */ opal_accelerator_rocm_lazy_init(); + // on Frontier the host can access any device memory + *flags |= MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY; + *dev_id = srcAttr.device; ret = 1; #if HIP_VERSION >= 50731921 } else if (hipMemoryTypeUnified == srcAttr.type) { @@ -105,12 +129,24 @@ static int mca_accelerator_rocm_check_addr (const void *addr, int *dev_id, uint6 //*flags |= MCA_ACCELERATOR_FLAGS_HOST_LDSTR; //*flags |= MCA_ACCELERATOR_FLAGS_HOST_ATOMICS; ret = 1; + *dev_id = srcAttr.device; } } + //printf("mca_accelerator_rocm_check_addr %p dev %d ret %d\n", addr, *dev_id, ret); return ret; } +static int mca_accelerator_rocm_get_default_stream(int dev_id, opal_accelerator_stream_t **stream) +{ + int delayed_init = opal_accelerator_rocm_lazy_init(); + if (OPAL_UNLIKELY(0 != delayed_init)) { + return delayed_init; + } + *stream = &opal_accelerator_rocm_default_stream.base; + return OPAL_SUCCESS; +} + static int mca_accelerator_rocm_create_stream(int dev_id, opal_accelerator_stream_t **stream) { if (NULL == stream) { @@ -291,14 +327,14 @@ static int mca_accelerator_rocm_memcpy(int dest_dev_id, int src_dev_id, void *de if (opal_accelerator_rocm_memcpy_async) { err = hipMemcpyAsync(dest, src, size, hipMemcpyDefault, - opal_accelerator_rocm_MemcpyStream); + *opal_accelerator_rocm_MemcpyStream); if 
(hipSuccess != err ) { opal_output_verbose(10, opal_accelerator_base_framework.framework_output, "error starting async copy\n"); return OPAL_ERROR; } - err = hipStreamSynchronize(opal_accelerator_rocm_MemcpyStream); + err = hipStreamSynchronize(*opal_accelerator_rocm_MemcpyStream); if (hipSuccess != err ) { opal_output_verbose(10, opal_accelerator_base_framework.framework_output, "error synchronizing stream after async copy\n"); @@ -316,6 +352,43 @@ static int mca_accelerator_rocm_memcpy(int dest_dev_id, int src_dev_id, void *de return OPAL_SUCCESS; } + +static int mca_accelerator_rocm_memmove_async(int dest_dev_id, int src_dev_id, void *dest, const void *src, size_t size, opal_accelerator_stream_t *stream, opal_accelerator_transfer_type_t type) +{ + hipDeviceptr_t tmp; + hipError_t result; + void *ptr; + + int delayed_init = opal_accelerator_rocm_lazy_init(); + if (OPAL_UNLIKELY(0 != delayed_init)) { + return delayed_init; + } + + if (NULL == dest || NULL == src || size <= 0) { + return OPAL_ERR_BAD_PARAM; + } + + result = mca_accelerator_rocm_mem_alloc_stream(src_dev_id, &ptr, size, stream); + if (OPAL_UNLIKELY(OPAL_SUCCESS != result)) { + return OPAL_ERROR; + } + tmp = (hipDeviceptr_t)ptr; + result = hipMemcpyAsync(tmp, (hipDeviceptr_t) src, size, hipMemcpyDefault, *(hipStream_t*)stream->stream); + if (OPAL_UNLIKELY(hipSuccess != result)) { + opal_show_help("help-accelerator-rocm.txt", "hipMemcpyAsync failed", true, tmp, src, size, + result); + return OPAL_ERROR; + } + result = hipMemcpyAsync((hipDeviceptr_t) dest, tmp, size, hipMemcpyDefault, *(hipStream_t*)stream->stream); + if (OPAL_UNLIKELY(hipSuccess != result)) { + opal_show_help("help-accelerator-rocm.txt", "hipMemcpyAsync failed", true, dest, tmp, + size, result); + return OPAL_ERROR; + } + return mca_accelerator_rocm_mem_release_stream(src_dev_id, ptr, stream); +} + + static int mca_accelerator_rocm_memmove(int dest_dev_id, int src_dev_id, void *dest, const void *src, size_t size, 
opal_accelerator_transfer_type_t type) @@ -336,7 +409,7 @@ static int mca_accelerator_rocm_memmove(int dest_dev_id, int src_dev_id, void *d if (opal_accelerator_rocm_memcpy_async) { err = hipMemcpyAsync(tmp, src, size, hipMemcpyDefault, - opal_accelerator_rocm_MemcpyStream); + *opal_accelerator_rocm_MemcpyStream); if (hipSuccess != err ) { opal_output_verbose(10, opal_accelerator_base_framework.framework_output, "error in async memcpy for memmove\n"); @@ -344,14 +417,14 @@ static int mca_accelerator_rocm_memmove(int dest_dev_id, int src_dev_id, void *d } err = hipMemcpyAsync(dest, tmp, size, hipMemcpyDefault, - opal_accelerator_rocm_MemcpyStream); + *opal_accelerator_rocm_MemcpyStream); if (hipSuccess != err ) { opal_output_verbose(10, opal_accelerator_base_framework.framework_output, "error in async memcpy for memmove\n"); return OPAL_ERROR; } - err = hipStreamSynchronize(opal_accelerator_rocm_MemcpyStream); + err = hipStreamSynchronize(*opal_accelerator_rocm_MemcpyStream); if (hipSuccess != err ) { opal_output_verbose(10, opal_accelerator_base_framework.framework_output, "error synchronizing stream for memmove\n"); @@ -566,3 +639,107 @@ static int mca_accelerator_rocm_get_buffer_id(int dev_id, const void *addr, opal #endif return OPAL_SUCCESS; } + +static int mca_accelerator_rocm_mem_alloc_stream( + int dev_id, + void **addr, + size_t size, + opal_accelerator_stream_t *stream) +{ +//#if HIP_VERSION >= ??? 
//TODO + hipError_t result; + + int delayed_init = opal_accelerator_rocm_lazy_init(); + if (OPAL_UNLIKELY(0 != delayed_init)) { + return delayed_init; + } + + if (NULL == stream || NULL == addr || 0 == size) { + return OPAL_ERR_BAD_PARAM; + } + + /* Try to allocate the memory from a memory pool, if available */ + /* get the default pool */ + hipMemPool_t mpool; + result = hipDeviceGetDefaultMemPool(&mpool, dev_id); + if (hipSuccess == result) { + result = hipMallocFromPoolAsync(addr, size, mpool, *(hipStream_t*)stream->stream); + if (hipSuccess == result) { + return OPAL_SUCCESS; + } + } + if (hipErrorNotSupported != result) { + opal_show_help("help-accelerator-rocm.txt", "hipMallocFromPoolAsync failed", true, + OPAL_PROC_MY_HOSTNAME, result); + return OPAL_ERROR; + } + /* fall-back to regular stream allocation */ + + result = hipMallocAsync(addr, size, *(hipStream_t*)stream->stream); + if (OPAL_UNLIKELY(hipSuccess != result)) { + opal_show_help("help-accelerator-rocm.txt", "hipMalloc failed", true, + OPAL_PROC_MY_HOSTNAME, result); + return OPAL_ERROR; + } + return OPAL_SUCCESS; +//#else +// return mca_accelerator_rocm_mem_alloc(dev_id, addr, size); +//#endif // HIP_VERSION +} + +static int mca_accelerator_rocm_mem_release_stream( + int dev_id, + void *addr, + opal_accelerator_stream_t *stream) +{ +//#if HIP_VERSION >= ??? 
//TODO + hipError_t result; + + if (NULL == stream || NULL == addr) { + return OPAL_ERR_BAD_PARAM; + } + + result = hipFreeAsync(addr, *(hipStream_t*)stream->stream); + if (OPAL_UNLIKELY(hipSuccess != result)) { + opal_show_help("help-accelerator-rocm.txt", "hipMalloc failed", true, + OPAL_PROC_MY_HOSTNAME, result); + return OPAL_ERROR; + } + return OPAL_SUCCESS; +//#else + /* wait for everything on the device to complete */ +// mca_accelerator_rocm_wait_stream(stream); +// return mca_accelerator_rocm_mem_release(dev_id, addr); +//#endif // HIP_VERSION >= 11020 +} + +static int mca_accelerator_rocm_wait_stream(opal_accelerator_stream_t *stream) +{ + hipError_t result; + result = hipStreamSynchronize(*(hipStream_t*)stream->stream); + if (OPAL_UNLIKELY(hipSuccess != result)) { + opal_show_help("help-accelerator-rocm.txt", "hipStreamSynchronize failed", true, + OPAL_PROC_MY_HOSTNAME, result); + return OPAL_ERROR; + } + return OPAL_SUCCESS; +} + + +static int mca_accelerator_rocm_get_num_devices(int *num_devices) +{ + *num_devices = opal_accelerator_rocm_num_devices; + return OPAL_SUCCESS; +} + +static int mca_accelerator_rocm_get_mem_bw(int device, float *bw) +{ + int delayed_init = opal_accelerator_rocm_lazy_init(); + if (OPAL_UNLIKELY(0 != delayed_init)) { + return delayed_init; + } + assert(opal_accelerator_rocm_mem_bw != NULL); + + *bw = opal_accelerator_rocm_mem_bw[device]; + return OPAL_SUCCESS; +} diff --git a/opal/mca/allocator/devicebucket/Makefile.am b/opal/mca/allocator/devicebucket/Makefile.am new file mode 100644 index 00000000000..466aad2671a --- /dev/null +++ b/opal/mca/allocator/devicebucket/Makefile.am @@ -0,0 +1,49 @@ +# +# Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana +# University Research and Technology +# Corporation. All rights reserved. +# Copyright (c) 2004-2023 The University of Tennessee and The University +# of Tennessee Research Foundation. All rights +# reserved. 
+# Copyright (c) 2004-2009 High Performance Computing Center Stuttgart, +# University of Stuttgart. All rights reserved. +# Copyright (c) 2004-2005 The Regents of the University of California. +# All rights reserved. +# Copyright (c) 2010 Cisco Systems, Inc. All rights reserved. +# Copyright (c) 2017 IBM Corporation. All rights reserved. +# $COPYRIGHT$ +# +# Additional copyrights may follow +# +# $HEADER$ +# + +dist_opaldata_DATA = help-mca-allocator-devicebucket.txt + +sources = \ + allocator_devicebucket.c \ + allocator_devicebucket_alloc.c \ + allocator_devicebucket_alloc.h + +# Make the output library in this directory, and name it either +# mca__.la (for DSO builds) or libmca__.la +# (for static builds). + +if MCA_BUILD_opal_allocator_devicebucket_DSO +component_noinst = +component_install = mca_allocator_devicebucket.la +else +component_noinst = libmca_allocator_devicebucket.la +component_install = +endif + +mcacomponentdir = $(opallibdir) +mcacomponent_LTLIBRARIES = $(component_install) +mca_allocator_devicebucket_la_SOURCES = $(sources) +mca_allocator_devicebucket_la_LDFLAGS = -module -avoid-version +mca_allocator_devicebucket_la_LIBADD = $(top_builddir)/opal/lib@OPAL_LIB_NAME@.la + +noinst_LTLIBRARIES = $(component_noinst) +libmca_allocator_devicebucket_la_SOURCES = $(sources) +libmca_allocator_devicebucket_la_LDFLAGS = -module -avoid-version + diff --git a/opal/mca/allocator/devicebucket/allocator_devicebucket.c b/opal/mca/allocator/devicebucket/allocator_devicebucket.c new file mode 100644 index 00000000000..e5b4f6d26f5 --- /dev/null +++ b/opal/mca/allocator/devicebucket/allocator_devicebucket.c @@ -0,0 +1,137 @@ +/* -*- Mode: C; c-basic-offset:4 ; indent-tabs-mode:nil -*- */ +/* + * Copyright (c) 2004-2007 The Trustees of Indiana University and Indiana + * University Research and Technology + * Corporation. All rights reserved. + * Copyright (c) 2004-2023 The University of Tennessee and The University + * of Tennessee Research Foundation. 
All rights + * reserved. + * Copyright (c) 2004-2005 High Performance Computing Center Stuttgart, + * University of Stuttgart. All rights reserved. + * Copyright (c) 2004-2005 The Regents of the University of California. + * All rights reserved. + * Copyright (c) 2008 Cisco Systems, Inc. All rights reserved. + * Copyright (c) 2014 Los Alamos National Security, LLC. All rights + * reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "opal_config.h" +#include "opal/constants.h" +#include "opal/mca/allocator/allocator.h" +#include "opal/mca/allocator/devicebucket/allocator_devicebucket_alloc.h" +#include "opal/mca/base/mca_base_var.h" + +OBJ_CLASS_INSTANCE(mca_allocator_devicebucket_chunk_t, opal_list_item_t, NULL, NULL); + +struct mca_allocator_base_module_t *mca_allocator_devicebucket_module_init( + bool enable_mpi_threads, mca_allocator_base_component_segment_alloc_fn_t segment_alloc, + mca_allocator_base_component_segment_free_fn_t segment_free, void *context); + +int mca_allocator_devicebucket_module_open(void); + +int mca_allocator_devicebucket_module_close(void); + +void *mca_allocator_devicebucket_alloc_wrapper(struct mca_allocator_base_module_t *allocator, size_t size, + size_t align); + +static size_t mca_allocator_min_cache_size; +static size_t mca_allocator_max_cache_size; + +int mca_allocator_devicebucket_finalize(struct mca_allocator_base_module_t *allocator) +{ + mca_allocator_devicebucket_t *mem_options = (mca_allocator_devicebucket_t *) allocator; + + mca_allocator_devicebucket_cleanup(allocator); + + OBJ_DESTRUCT(&mem_options->used_chunks); + + free(mem_options->buckets); + free(allocator); + + return (OPAL_SUCCESS); +} + +struct mca_allocator_base_module_t *mca_allocator_devicebucket_module_init( + bool enable_mpi_threads, mca_allocator_base_component_segment_alloc_fn_t segment_alloc, + mca_allocator_base_component_segment_free_fn_t segment_free, void *context) +{ + size_t alloc_size = 
sizeof(mca_allocator_devicebucket_t); + mca_allocator_devicebucket_t *retval; + mca_allocator_devicebucket_t *allocator = (mca_allocator_devicebucket_t *) malloc(alloc_size); + if (NULL == allocator) { + return NULL; + } + retval = mca_allocator_devicebucket_init((mca_allocator_base_module_t *) allocator, + mca_allocator_min_cache_size, mca_allocator_max_cache_size, + segment_alloc, segment_free); + if (NULL == retval) { + free(allocator); + return NULL; + } + allocator->super.alc_alloc = mca_allocator_devicebucket_alloc_wrapper; + allocator->super.alc_realloc = NULL; // not supported + allocator->super.alc_free = mca_allocator_devicebucket_free; + allocator->super.alc_compact = mca_allocator_devicebucket_cleanup; + allocator->super.alc_finalize = mca_allocator_devicebucket_finalize; + allocator->super.alc_context = context; + return (mca_allocator_base_module_t *) allocator; +} + +static int mca_allocator_devicebucket_module_register(void) +{ + mca_allocator_min_cache_size = 4*1024; // 4K + mca_allocator_max_cache_size = 1*1024*1024*1024; // 1G + (void) mca_base_component_var_register(&mca_allocator_devicebucket_component.allocator_version, + "min_cache_size", "Minimum allocation cache size", + MCA_BASE_VAR_TYPE_SIZE_T, NULL, 0, + MCA_BASE_VAR_FLAG_SETTABLE, OPAL_INFO_LVL_9, + MCA_BASE_VAR_SCOPE_LOCAL, &mca_allocator_min_cache_size); + + (void) mca_base_component_var_register(&mca_allocator_devicebucket_component.allocator_version, + "max_cache_size", + "Maximum allocation cache size. 
Larger allocations will not be cached.", + MCA_BASE_VAR_TYPE_SIZE_T, NULL, 0, + MCA_BASE_VAR_FLAG_SETTABLE, OPAL_INFO_LVL_9, + MCA_BASE_VAR_SCOPE_LOCAL, &mca_allocator_max_cache_size); + return OPAL_SUCCESS; +} + +int mca_allocator_devicebucket_module_open(void) +{ + return OPAL_SUCCESS; +} + +int mca_allocator_devicebucket_module_close(void) +{ + return OPAL_SUCCESS; +} + +void *mca_allocator_devicebucket_alloc_wrapper(struct mca_allocator_base_module_t *allocator, size_t size, + size_t align) +{ + if (0 == align) { + return mca_allocator_devicebucket_alloc(allocator, size); + } + return mca_allocator_devicebucket_alloc_align(allocator, size, align); +} + +mca_allocator_base_component_t mca_allocator_devicebucket_component = { + + /* First, the mca_base_module_t struct containing meta information + about the module itself */ + + {MCA_ALLOCATOR_BASE_VERSION_2_0_0, + + "devicebucket", /* MCA module name */ + OPAL_MAJOR_VERSION, OPAL_MINOR_VERSION, OPAL_RELEASE_VERSION, + mca_allocator_devicebucket_module_open, /* module open */ + mca_allocator_devicebucket_module_close, /* module close */ + NULL, mca_allocator_devicebucket_module_register}, + {/* The component is checkpoint ready */ + MCA_BASE_METADATA_PARAM_CHECKPOINT}, + mca_allocator_devicebucket_module_init}; diff --git a/opal/mca/allocator/devicebucket/allocator_devicebucket_alloc.c b/opal/mca/allocator/devicebucket/allocator_devicebucket_alloc.c new file mode 100644 index 00000000000..46bafe4cdae --- /dev/null +++ b/opal/mca/allocator/devicebucket/allocator_devicebucket_alloc.c @@ -0,0 +1,209 @@ +/* + * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana + * University Research and Technology + * Corporation. All rights reserved. + * Copyright (c) 2004-2023 The University of Tennessee and The University + * of Tennessee Research Foundation. All rights + * reserved. + * Copyright (c) 2004-2005 High Performance Computing Center Stuttgart, + * University of Stuttgart. All rights reserved. 
+ * Copyright (c) 2004-2005 The Regents of the University of California. + * All rights reserved. + * Copyright (c) 2007 IBM Corp., All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "opal_config.h" +#include "opal/mca/allocator/devicebucket/allocator_devicebucket_alloc.h" +#include "opal/constants.h" +#include "opal/util/show_help.h" + +/** + * The define controls the size in bytes of the 1st bucket and hence every one + * afterwards. + */ +#define MCA_ALLOCATOR_BUCKET_1_SIZE 8 +/** + * This is the number of left bit shifts from 1 needed to get to the number of + * bytes in the initial memory buckets + */ +#define MCA_ALLOCATOR_BUCKET_1_BITSHIFTS 3 + +static int max_devicebucket_idx; + +/* + * Initializes the mca_allocator_devicebucket_options_t data structure for the passed + * parameters. + */ +mca_allocator_devicebucket_t * +mca_allocator_devicebucket_init(mca_allocator_base_module_t *mem, + size_t min_cache_size, size_t max_cache_size, + mca_allocator_base_component_segment_alloc_fn_t get_mem_funct, + mca_allocator_base_component_segment_free_fn_t free_mem_funct) +{ + mca_allocator_devicebucket_t *mem_options = (mca_allocator_devicebucket_t *) mem; + size_t size; + /* if a bad value is used for the number of buckets, default to 30 */ + int num_buckets = 1; + /* round min_cache_size down to pow2 */ + size = 1; + while (size < min_cache_size) { + size <<= 1; + } + min_cache_size = size; + while (size < max_cache_size) { + size <<= 1; + num_buckets++; + } + //printf("min_cache_size %zu max_cache_size %zu num_buckets %d\n", min_cache_size, max_cache_size, num_buckets); + max_devicebucket_idx = num_buckets - 1; + + /* initialize the array of buckets */ + size = sizeof(mca_allocator_devicebucket_bucket_t) * num_buckets; + mem_options->buckets = (mca_allocator_devicebucket_bucket_t *) malloc(size); + if (NULL == mem_options->buckets) { + return (NULL); + } + for (int i = 0; i < num_buckets; i++) { + 
OBJ_CONSTRUCT(&(mem_options->buckets[i].super), opal_lifo_t); + mem_options->buckets[i].size = (min_cache_size << i); + } + mem_options->num_buckets = num_buckets; + mem_options->get_mem_fn = get_mem_funct; + mem_options->free_mem_fn = free_mem_funct; + mem_options->min_cache_size = min_cache_size; + OBJ_CONSTRUCT(&mem_options->used_chunks, opal_hash_table_t); + opal_hash_table_init(&mem_options->used_chunks, 32); + OBJ_CONSTRUCT(&(mem_options->used_chunks_lock), opal_mutex_t); + return (mem_options); +} + +/* + * Accepts a request for memory in a specific region defined by the + * mca_allocator_devicebucket_options_t struct and returns a pointer to memory in that + * region or NULL if there was an error + * + */ +void *mca_allocator_devicebucket_alloc(mca_allocator_base_module_t *mem, size_t size) +{ + mca_allocator_devicebucket_t *mem_options = (mca_allocator_devicebucket_t *) mem; + /* initialize for the later bit shifts */ + int bucket_num = 0; + size_t bucket_size = mem_options->min_cache_size; + mca_allocator_devicebucket_chunk_t *chunk; + + /* figure out which bucket it will come from. 
*/ + while (size > bucket_size) { + bucket_num++; + bucket_size <<= 1; + } + + //printf("mca_allocator_devicebucket_alloc checking bucket %d of %d for size %d\n", bucket_num, mem_options->num_buckets, bucket_size); + if (bucket_num >= mem_options->num_buckets) { + /* allocate directly */ + chunk = OBJ_NEW(mca_allocator_devicebucket_chunk_t); + chunk->addr = mem_options->get_mem_fn(mem_options->super.alc_context, &size); + chunk->size = size; + } else { + /* see if there is already a free chunk */ + chunk = (mca_allocator_devicebucket_chunk_t *)opal_lifo_pop(&(mem_options->buckets[bucket_num].super)); + if (NULL == chunk) { + /* create a new allocation */ + chunk = OBJ_NEW(mca_allocator_devicebucket_chunk_t); + if (NULL == chunk) { + return NULL; + } + chunk->addr = mem_options->get_mem_fn(mem_options->super.alc_context, &bucket_size); + chunk->size = bucket_size; + } + } + /* store the chunk in the hash table so we can find it during free */ + OPAL_THREAD_LOCK(&(mem_options->used_chunks_lock)); + opal_hash_table_set_value_uint64(&(mem_options->used_chunks), (uint64_t)chunk->addr, chunk); + OPAL_THREAD_UNLOCK(&(mem_options->used_chunks_lock)); + //printf("Allocated chunk %p for address %p\n", chunk, chunk->addr); + return chunk->addr; +} + +/* + * allocates an aligned region of memory + */ +void *mca_allocator_devicebucket_alloc_align(mca_allocator_base_module_t *mem, size_t size, + size_t alignment) +{ + return mca_allocator_devicebucket_alloc(mem, size); +} + +/* + * function to reallocate the segment of memory + */ +void *mca_allocator_devicebucket_realloc(mca_allocator_base_module_t *mem, void *ptr, size_t size) +{ + mca_allocator_devicebucket_t *mem_options = (mca_allocator_devicebucket_t *) mem; + // TODO: do something nice here + return NULL; +} + +/* + * Frees the passed region of memory + * + */ +void mca_allocator_devicebucket_free(mca_allocator_base_module_t *mem, void *ptr) +{ + mca_allocator_devicebucket_t *mem_options = (mca_allocator_devicebucket_t *) 
mem; + size_t bucket_size = mem_options->min_cache_size; + size_t allocated_size; + int bucket_num = 0; + mca_allocator_devicebucket_chunk_t *chunk = NULL; + + OPAL_THREAD_LOCK(&(mem_options->used_chunks_lock)); + opal_hash_table_get_value_uint64(&(mem_options->used_chunks), (uint64_t)ptr, (void**)&chunk); + if (NULL == chunk) { + printf("Couldn't find chunk for address %p\n", ptr); + OPAL_THREAD_UNLOCK(&(mem_options->used_chunks_lock)); + return; + } + opal_hash_table_remove_value_uint64(&(mem_options->used_chunks), (uint64_t)ptr); + OPAL_THREAD_UNLOCK(&(mem_options->used_chunks_lock)); + size_t size = chunk->size; + + /* figure out which bucket to put the chunk into. */ + while (size > bucket_size) { + bucket_num++; + bucket_size <<= 1; + } + + if (bucket_num >= mem_options->num_buckets) { + mem_options->free_mem_fn(mem_options->super.alc_context, ptr); + OBJ_RELEASE(chunk); + } else { + /* push into lifo */ + opal_lifo_push(&(mem_options->buckets[bucket_num].super), &chunk->super); + } +} + +/* + * Frees all the memory from all the buckets back to the system. Note that + * this function only frees memory that was previously freed with + * mca_allocator_devicebucket_free(). 
+ * + */ +int mca_allocator_devicebucket_cleanup(mca_allocator_base_module_t *mem) +{ + mca_allocator_devicebucket_t *mem_options = (mca_allocator_devicebucket_t *) mem; + mca_allocator_devicebucket_chunk_t *chunk; + + for (int i = 0; i < mem_options->num_buckets; i++) { + while (NULL != (chunk = (mca_allocator_devicebucket_chunk_t *)opal_lifo_pop(&(mem_options->buckets[i].super)))) { + if (mem_options->free_mem_fn) { + mem_options->free_mem_fn(mem->alc_context, chunk->addr); + } + OBJ_RELEASE(chunk); + } + } + return OPAL_SUCCESS; +} diff --git a/opal/mca/allocator/devicebucket/allocator_devicebucket_alloc.h b/opal/mca/allocator/devicebucket/allocator_devicebucket_alloc.h new file mode 100644 index 00000000000..b313bd91d6f --- /dev/null +++ b/opal/mca/allocator/devicebucket/allocator_devicebucket_alloc.h @@ -0,0 +1,190 @@ +/* -*- Mode: C; c-basic-offset:4 ; indent-tabs-mode:nil -*- */ +/* + * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana + * University Research and Technology + * Corporation. All rights reserved. + * Copyright (c) 2004-2023 The University of Tennessee and The University + * of Tennessee Research Foundation. All rights + * reserved. + * Copyright (c) 2004-2005 High Performance Computing Center Stuttgart, + * University of Stuttgart. All rights reserved. + * Copyright (c) 2004-2005 The Regents of the University of California. + * All rights reserved. + * Copyright (c) 2015 Los Alamos National Security, LLC. All rights + * reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +/** @file + * A generic memory bucket allocator. 
+ **/ + +#ifndef ALLOCATOR_DEVICEBUCKET_ALLOC_H +#define ALLOCATOR_DEVICEBUCKET_ALLOC_H + +#include "opal_config.h" +#include "opal/mca/allocator/allocator.h" +#include "opal/mca/threads/mutex.h" +#include "opal/class/opal_lifo.h" +#include "opal/class/opal_hash_table.h" +#include +#include + +BEGIN_C_DECLS + +/** + * Structure for the header of each memory chunk + */ +struct mca_allocator_devicebucket_chunk_t { + opal_list_item_t super; + void *addr; // address + size_t size; +}; + +/** + * Typedef so we don't have to use struct + */ +typedef struct mca_allocator_devicebucket_chunk_t mca_allocator_devicebucket_chunk_t; + +OPAL_DECLSPEC OBJ_CLASS_DECLARATION(mca_allocator_devicebucket_chunk_t); + +struct mca_allocator_devicebucket_bucket_t { + opal_lifo_t super; + size_t size; +}; + +/** + * Typedef so we don't have to use struct + */ +typedef struct mca_allocator_devicebucket_bucket_t mca_allocator_devicebucket_bucket_t; + +/** + * Structure that holds the necessary information for each area of memory + */ +struct mca_allocator_devicebucket_t { + mca_allocator_base_module_t super; /**< makes this a child of class mca_allocator_t */ + mca_allocator_devicebucket_bucket_t *buckets; /**< the array of buckets */ + int num_buckets; /**< the number of buckets */ + opal_hash_table_t used_chunks; + opal_mutex_t used_chunks_lock; + size_t min_cache_size; + mca_allocator_base_component_segment_alloc_fn_t get_mem_fn; + /**< pointer to the function to get more memory */ + mca_allocator_base_component_segment_free_fn_t free_mem_fn; + /**< pointer to the function to free memory */ +}; +/** + * Typedef so we don't have to use struct + */ +typedef struct mca_allocator_devicebucket_t mca_allocator_devicebucket_t; + +/** + * Initializes the mca_allocator_devicebucket_options_t data structure for the passed + * parameters. 
+ * @param mem a pointer to the mca_allocator_t struct to be filled in + * @param num_buckets The number of buckets the allocator will use + * @param get_mem_funct A pointer to the function that the allocator + * will use to get more memory + * @param free_mem_funct A pointer to the function that the allocator + * will use to free memory + * + * @retval Pointer to the initialized mca_allocator_devicebucket_options_t structure + * @retval NULL if there was an error + */ +mca_allocator_devicebucket_t * +mca_allocator_devicebucket_init(mca_allocator_base_module_t *mem, + size_t min_cache_size, size_t max_cache_size, + mca_allocator_base_component_segment_alloc_fn_t get_mem_funct, + mca_allocator_base_component_segment_free_fn_t free_mem_funct); +/** + * Accepts a request for memory in a specific region defined by the + * mca_allocator_devicebucket_options_t struct and returns a pointer to memory in that + * region or NULL if there was an error + * + * @param mem A pointer to the appropriate struct for the area of memory. + * @param size The size of the requested area of memory + * + * @retval Pointer to the area of memory if the allocation was successful + * @retval NULL if the allocation was unsuccessful + */ +void *mca_allocator_devicebucket_alloc(mca_allocator_base_module_t *mem, size_t size); + +/** + * Accepts a request for memory in a specific region defined by the + * mca_allocator_devicebucket_options_t struct and aligned by the specified amount and + * returns a pointer to memory in that region or NULL if there was an error + * + * @param mem A pointer to the appropriate struct for the area of + * memory. + * @param size The size of the requested area of memory + * @param alignment The requested alignment of the new area of memory. This + * MUST be a power of 2. 
+ * + * @retval Pointer to the area of memory if the allocation was successful + * @retval NULL if the allocation was unsuccessful + * + */ +void *mca_allocator_devicebucket_alloc_align(mca_allocator_base_module_t *mem, size_t size, + size_t alignment); + +/** + * Attempts to resize the passed region of memory into a larger or a smaller + * region. If it is unsuccessful, it will return NULL and the passed area of + * memory will be untouched. + * + * @param mem A pointer to the appropriate struct for the area of + * memory. + * @param size The size of the requested area of memory + * @param ptr A pointer to the region of memory to be resized + * + * @retval Pointer to the area of memory if the reallocation was successful + * @retval NULL if the allocation was unsuccessful + * + */ +void *mca_allocator_devicebucket_realloc(mca_allocator_base_module_t *mem, void *ptr, size_t size); + +/** + * Frees the passed region of memory + * + * @param mem A pointer to the appropriate struct for the area of + * memory. + * @param ptr A pointer to the region of memory to be freed + * + * @retval None + * + */ +void mca_allocator_devicebucket_free(mca_allocator_base_module_t *mem, void *ptr); + +/** + * Frees all the memory from all the buckets back to the system. Note that + * this function only frees memory that was previously freed with + * mca_allocator_devicebucket_free(). + * + * @param mem A pointer to the appropriate struct for the area of + * memory. + * + * @retval None + * + */ +int mca_allocator_devicebucket_cleanup(mca_allocator_base_module_t *mem); + +/** + * Cleanup all resources held by this allocator. + * + * @param mem A pointer to the appropriate struct for the area of + * memory. 
+ * + * @retval None + * + */ +int mca_allocator_devicebucket_finalize(mca_allocator_base_module_t *mem); + +OPAL_DECLSPEC extern mca_allocator_base_component_t mca_allocator_devicebucket_component; + +END_C_DECLS + +#endif /* ALLOCATOR_DEVICEBUCKET_ALLOC_H */ diff --git a/opal/mca/allocator/devicebucket/help-mca-allocator-devicebucket.txt b/opal/mca/allocator/devicebucket/help-mca-allocator-devicebucket.txt new file mode 100644 index 00000000000..01c152fd26d --- /dev/null +++ b/opal/mca/allocator/devicebucket/help-mca-allocator-devicebucket.txt @@ -0,0 +1,21 @@ +# -*- text -*- +# +# Copyright (c) 2004-2023 The University of Tennessee and The University +# of Tennessee Research Foundation. All rights +# reserved. +# $COPYRIGHT$ +# +# Additional copyrights may follow +# +# $HEADER$ +# +# This is the US/English help file for Open MPI's allocator bucket support +# +[buffer too large] +ERROR: Requested buffer size %zu exceeds limit of %zu +Consider setting "%s" to %d +# +[aligned buffer too large] +ERROR: Requested aligned buffer size %zu exceeds limit of %zu +Consider setting "%s" to %d +# diff --git a/opal/mca/allocator/devicebucket/owner.txt b/opal/mca/allocator/devicebucket/owner.txt new file mode 100644 index 00000000000..c47a2d510b1 --- /dev/null +++ b/opal/mca/allocator/devicebucket/owner.txt @@ -0,0 +1,7 @@ +# +# owner/status file +# owner: institution that is responsible for this package +# status: e.g. 
active, maintenance, unmaintained +# +owner: UTK +status: maintenance diff --git a/opal/mca/dl/dlopen/configure.m4 b/opal/mca/dl/dlopen/configure.m4 index 07fda820016..4ae625b1fb5 100644 --- a/opal/mca/dl/dlopen/configure.m4 +++ b/opal/mca/dl/dlopen/configure.m4 @@ -27,7 +27,7 @@ AC_DEFUN([MCA_opal_dl_dlopen_CONFIG],[ AC_CONFIG_FILES([opal/mca/dl/dlopen/Makefile]) OAC_CHECK_PACKAGE([dlopen], - [dl_dlopen], + [opal_dl_dlopen], [dlfcn.h], [dl], [dlopen], @@ -38,5 +38,5 @@ AC_DEFUN([MCA_opal_dl_dlopen_CONFIG],[ [$1], [$2]) - AC_SUBST(dl_dlopen_LIBS) + AC_SUBST(opal_dl_dlopen_LIBS) ]) diff --git a/opal/mca/threads/base/threads_base.c b/opal/mca/threads/base/threads_base.c index 227aaeb64d3..d3ae6b5ccaf 100644 --- a/opal/mca/threads/base/threads_base.c +++ b/opal/mca/threads/base/threads_base.c @@ -24,6 +24,7 @@ #include "opal/constants.h" #include "opal/mca/threads/base/base.h" +#include "opal/mca/threads/threads.h" #if OPAL_ENABLE_DEBUG bool opal_debug_threads = false; diff --git a/test/datatype/Makefile.am b/test/datatype/Makefile.am index 3d6fd3289b5..9ae81e3dff9 100644 --- a/test/datatype/Makefile.am +++ b/test/datatype/Makefile.am @@ -96,8 +96,11 @@ unpack_hetero_LDFLAGS = $(OMPI_PKG_CONFIG_LDFLAGS) unpack_hetero_LDADD = \ $(top_builddir)/opal/lib@OPAL_LIB_NAME@.la + reduce_local_SOURCES = reduce_local.c -reduce_local_LDFLAGS = $(OMPI_PKG_CONFIG_LDFLAGS) +reduce_local_CPPFLAGS= $(accelerator_cudart_CPPFLAGS) $(accelerator_cuda_CPPFLAGS) +reduce_local_LDFLAGS = $(OMPI_PKG_CONFIG_LDFLAGS)\ + $(accelerator_cuda_LDFLAGS) $(accelerator_cudart_LDFLAGS) reduce_local_LDADD = \ $(top_builddir)/ompi/lib@OMPI_LIBMPI_NAME@.la \ $(top_builddir)/opal/lib@OPAL_LIB_NAME@.la diff --git a/test/datatype/reduce_local.c b/test/datatype/reduce_local.c index 17259cd2b18..8ec54a243d8 100644 --- a/test/datatype/reduce_local.c +++ b/test/datatype/reduce_local.c @@ -20,11 +20,17 @@ #include #include +// TODO: detect through configure +//#define HAVE_CUDA 1 +//#define HAVE_ROCM 1 + #include 
"mpi.h" #include "ompi/communicator/communicator.h" #include "ompi/datatype/ompi_datatype.h" #include "ompi/runtime/mpiruntime.h" +#include "opal_config.h" + typedef struct op_name_s { char *name; char *mpi_op_name; @@ -64,20 +70,33 @@ static int total_errors = 0; _a < _b ? _a : _b; \ }) +static void print_header(int max_shift) { + printf("%-10s %-10s %-10s %-10s %-10s", "Op", "Type", "TypeSize", "Check", "Count"); + if (1 == max_shift) { + printf("%-10s", "'Time (seconds)'"); + } else { + for (int i = 0; i < max_shift; ++i) { + char str[128]; + snprintf(str, 128, "'Shift %d [s]'", i); + printf(" %-10s ", str); + } + } + printf("\n"); +} + static void print_status(char *op, char *type, int type_size, int count, int max_shift, double *duration, int repeats, int correct) { if (correct) { - printf("%-10s %s %-10d%s ", op, type, type_size, - (verbose ? " [\033[1;32msuccess\033[0m]" : "")); + printf("%-10s %s %-10d success", op, type, type_size); } else { - printf("%-10s %s [\033[1;31mfail\033[0m]", op, type); + printf("%-10s %s %-10d [\033[1;31mfail\033[0m]", op, type, type_size); total_errors++; } if (1 == max_shift) { - printf(" count %-10d time (seconds) %.8f seconds\n", count, duration[0] / repeats); + printf(" %-10d %.8f\n", count, duration[0] / repeats); } else { - printf(" count %-10d time (seconds / shifts) ", count); + printf(" %-10d ", count); for (int i = 0; i < max_shift; i++) { printf("%.8f ", duration[i] / repeats); } @@ -123,25 +142,25 @@ static int build_do_ops(char *optarg, int *do_ops) } /* clang-format off */ -#define MPI_OP_TEST(OPNAME, MPIOP, MPITYPE, TYPE, INBUF, INOUT_BUF, CHECK_BUF, COUNT, TYPE_PREFIX) \ +#define MPI_OP_TEST(OPNAME, MPIOP, MPITYPE, TYPE, INIT_INBUF, INBUF, INIT_INOUT_BUF, INOUT_BUF, CHECK_BUF, COUNT, TYPE_PREFIX) \ do { \ - const TYPE *_p1 = ((TYPE*)(INBUF)), *_p3 = ((TYPE*)(CHECK_BUF)); \ - TYPE *_p2 = ((TYPE*)(INOUT_BUF)); \ skip_op_type = 0; \ + allocator->memcpy(INBUF, INIT_INBUF, sizeof(TYPE) * (COUNT)); \ for(int _k = 0; 
_k < min((COUNT), max_shift); +_k++ ) { \ duration[_k] = 0.0; \ for(int _r = repeats; _r > 0; _r--) { \ - memcpy(_p2, _p3, sizeof(TYPE) * (COUNT)); \ + allocator->memcpy(INOUT_BUF, INIT_INOUT_BUF, sizeof(TYPE) * (COUNT)); \ tstart = MPI_Wtime(); \ - MPI_Reduce_local(_p1+_k, _p2+_k, (COUNT)-_k, (MPITYPE), (MPIOP)); \ + MPI_Reduce_local(INBUF+_k, INOUT_BUF+_k, (COUNT)-_k, (MPITYPE), (MPIOP)); \ tend = MPI_Wtime(); \ duration[_k] += (tend - tstart); \ if( check ) { \ + allocator->memcpy(CHECK_BUF, INOUT_BUF, sizeof(TYPE) * (COUNT)); \ for( i = 0; i < (COUNT)-_k; i++ ) { \ - if(((_p2+_k)[i]) == (((_p1+_k)[i]) OPNAME ((_p3+_k)[i]))) \ + if(((CHECK_BUF+_k)[i]) == (((INIT_INBUF+_k)[i]) OPNAME ((INIT_INOUT_BUF+_k)[i]))) \ continue; \ printf("First error at alignment %d position %d (%" TYPE_PREFIX " %s %" TYPE_PREFIX " != %" TYPE_PREFIX ")\n", \ - _k, i, (_p1+_k)[i], (#OPNAME), (_p3+_k)[i], (_p2+_k)[i]); \ + _k, i, (INIT_INBUF+_k)[i], (#OPNAME), (INIT_INOUT_BUF+_k)[i], (INIT_INOUT_BUF+_k)[i]); \ correctness = 0; \ break; \ } \ @@ -151,26 +170,26 @@ do { \ goto check_and_continue; \ } while (0) -#define MPI_OP_MINMAX_TEST(OPNAME, MPIOP, MPITYPE, TYPE, INBUF, INOUT_BUF, CHECK_BUF, COUNT, TYPE_PREFIX) \ +#define MPI_OP_MINMAX_TEST(OPNAME, MPIOP, MPITYPE, TYPE, INIT_INBUF, INBUF, INIT_INOUT_BUF, INOUT_BUF, CHECK_BUF, COUNT, TYPE_PREFIX) \ do { \ - const TYPE *_p1 = ((TYPE*)(INBUF)), *_p3 = ((TYPE*)(CHECK_BUF)); \ - TYPE *_p2 = ((TYPE*)(INOUT_BUF)); \ skip_op_type = 0; \ + allocator->memcpy(INBUF, INIT_INBUF, sizeof(TYPE) * (COUNT)); \ for(int _k = 0; _k < min((COUNT), max_shift); +_k++ ) { \ duration[_k] = 0.0; \ for(int _r = repeats; _r > 0; _r--) { \ - memcpy(_p2, _p3, sizeof(TYPE) * (COUNT)); \ + allocator->memcpy(INOUT_BUF, INIT_INOUT_BUF, sizeof(TYPE) * (COUNT)); \ tstart = MPI_Wtime(); \ - MPI_Reduce_local(_p1+_k, _p2+_k, (COUNT)-_k, (MPITYPE), (MPIOP)); \ + MPI_Reduce_local(INBUF+_k, INOUT_BUF+_k, (COUNT)-_k, (MPITYPE), (MPIOP)); \ tend = MPI_Wtime(); \ duration[_k] += 
(tend - tstart); \ if( check ) { \ + allocator->memcpy(CHECK_BUF, INOUT_BUF, sizeof(TYPE) * (COUNT)); \ for( i = 0; i < (COUNT); i++ ) { \ - TYPE _v1 = *(_p1+_k), _v2 = *(_p2+_k), _v3 = *(_p3+_k); \ + TYPE _v1 = *(INIT_INBUF+_k), _v2 = *(CHECK_BUF+_k), _v3 = *(INIT_INOUT_BUF+_k); \ if(_v2 == OPNAME(_v1, _v3)) \ continue; \ printf("First error at alignment %d position %d (%" TYPE_PREFIX " != %s(%" TYPE_PREFIX ", %" TYPE_PREFIX ")\n", \ - _k, i, _v1, (#OPNAME), _v3, _v2); \ + _k, i, _v2, (#OPNAME), _v1, _v2); \ correctness = 0; \ break; \ } \ @@ -181,19 +200,142 @@ do { \ } while (0) /* clang-format on */ +static +void *host_allocate(size_t size, size_t align) { + void *ptr; + posix_memalign(&ptr, align, size); + return ptr; +} +static void host_free(void *ptr) { + free(ptr); +} +static void host_init(void) { + // nothing to do +} +static void host_fini(void) { + // nothing to do +} +static void* host_memcpy(void *dst, const void *src, size_t size) { + return memcpy(dst, src, size); +} + +typedef void*(*allocate_fn_t)(size_t, size_t); +typedef void*(*memcpy_fn_t)(void*, const void*, size_t); +typedef void(*free_fn_t)(void*); +typedef void(*init_fn_t)(void); +typedef void(*fini_fn_t)(void); + +enum ALLOCATOR_FLAGS { + ALLOCATOR_DISCRETE = 1, +}; + +typedef struct { + int flags; + init_fn_t init; + allocate_fn_t allocate; + memcpy_fn_t memcpy; + free_fn_t free; + fini_fn_t fini; +} allocator_t; + +static allocator_t host_allocator = { + .flags = 0, + .init = &host_init, + .allocate = &host_allocate, + .memcpy = &host_memcpy, + .free = &host_free, + .fini = &host_fini}; + +#if defined(OPAL_CUDA_SUPPORT) && OPAL_CUDA_SUPPORT +#include +static void cuda_init() { + // nothing to be done +} +static void *cuda_allocate(size_t size, size_t align) { + (void)align; // ignored + void *ptr; + int err; + if (cudaSuccess != (err = cudaMalloc(&ptr, size))) { + fprintf(stderr, "cudaMalloc failed to allocate %zuB: %s", size, cudaGetErrorName(err)); + return NULL; + } + return ptr; +} 
+static void* cuda_memcpy(void *dst, const void *src, size_t size) { + cudaMemcpy(dst, src, size, cudaMemcpyDefault); + cudaDeviceSynchronize(); + return dst; +} +static void cuda_free(void *ptr) { + cudaFree(ptr); +} +static void cuda_fini() { + // nothing to be done +} +static allocator_t cuda_allocator = { + .flags = ALLOCATOR_DISCRETE, + .init = &cuda_init, + .allocate = &cuda_allocate, + .memcpy = &cuda_memcpy, + .free = &cuda_free, + .fini = &cuda_fini}; + +#elif defined(OPAL_ROCM_SUPPORT) && OPAL_ROCM_SUPPORT +#include +static void rocm_init() { + hipError_t ret = hipInit(0); + assert(hipSuccess == ret); + int num_devs = 0; + ret = hipGetDeviceCount(&num_devs); + assert(hipSuccess == ret); + assert(num_devs > 0); + ret = hipSetDevice(0); + assert(hipSuccess == ret); +} +static void *rocm_allocate(size_t size, size_t align) { + (void)align; // ignored + void *ptr; + int err; + if (hipSuccess != (err = hipMalloc(&ptr, size))) { + fprintf(stderr, "hipMalloc failed to allocate %zuB: %s", size, hipGetErrorName(err)); + return NULL; + } + return ptr; +} +static void* rocm_memcpy(void *dst, const void *src, size_t size) { + hipMemcpy(dst, src, size, hipMemcpyDefault); + return dst; +} +static void rocm_free(void *ptr) { + hipFree(ptr); +} +static void rocm_fini() { + // nothing to be done +} +static allocator_t rocm_allocator = { + .flags = ALLOCATOR_DISCRETE, + .init = &rocm_init, + .allocate = &rocm_allocate, + .memcpy = &rocm_memcpy, + .free = &rocm_free, + .fini = &rocm_fini}; + +#endif + int main(int argc, char **argv) { - static void *in_buf = NULL, *inout_buf = NULL, *inout_check_buf = NULL; + static void *in_buf = NULL, *inout_buf = NULL, *inout_check_buf = NULL, *init_in_buf = NULL, *init_inout_buf = NULL; int count, type_size = 8, rank, size, provided, correctness = 1; int repeats = 1, i, c, op1_alignment = 0, res_alignment = 0; int max_shift = 4; double *duration, tstart, tend; + allocator_t *allocator = &host_allocator; bool check = true; char type[5] = 
"uifd", *op = "sum", *mpi_type; int lower = 1, upper = 1000000, skip_op_type; MPI_Op mpi_op; - while (-1 != (c = getopt(argc, argv, "l:u:r:t:o:i:s:n:1:2:vfh"))) { + while (-1 != (c = getopt(argc, argv, "l:u:r:t:o:i:s:n:1:2:d:vfh"))) { switch (c) { case 'l': lower = atoi(optarg); @@ -267,6 +409,26 @@ int main(int argc, char **argv) exit(-1); } break; + case 'd': + if (0 == strncmp("host", optarg, 4)) { + // default allocator + break; + } else +#if defined(OPAL_CUDA_SUPPORT) && OPAL_CUDA_SUPPORT + if (0 == strncmp("cuda", optarg, 4)) { + allocator = &cuda_allocator; + break; + } else +#elif defined(OPAL_ROCM_SUPPORT) && OPAL_ROCM_SUPPORT + if (0 == strncmp("rocm", optarg, 4)) { + allocator = &rocm_allocator; + break; + } else +#endif + { + fprintf(stderr, "Unsupported allocator: %s\n", optarg); + // fall-through + } case 'h': fprintf(stdout, "%s options are:\n" @@ -277,6 +439,14 @@ int main(int argc, char **argv) " -r : number of repetitions for each test\n" " -o : comma separated list of operations to execute among\n" " sum, min, max, prod, bor, bxor, band\n" + " -d : host" +#if defined(OPAL_CUDA_SUPPORT) && OPAL_CUDA_SUPPORT + ", cuda" +#endif +#if defined(OPAL_ROCM_SUPPORT) && OPAL_ROCM_SUPPORT + ", rocm" +#endif + "\n" " -i : shift on all buffers to check alignment\n" " -1 : (mis)alignment in elements for the first op\n" " -2 : (mis)alignment in elements for the result\n" @@ -291,18 +461,23 @@ int main(int argc, char **argv) if (!do_ops_built) { /* not yet done, take the default */ build_do_ops("all", do_ops); } - posix_memalign(&in_buf, 64, (upper + op1_alignment) * sizeof(double)); - posix_memalign(&inout_buf, 64, (upper + res_alignment) * sizeof(double)); - posix_memalign(&inout_check_buf, 64, upper * sizeof(double)); - duration = (double *) malloc(max_shift * sizeof(double)); - ompi_mpi_init(argc, argv, MPI_THREAD_SERIALIZED, &provided, false); + allocator->init(); + in_buf = allocator->allocate((upper + op1_alignment) * sizeof(double), 64); + inout_buf = 
allocator->allocate((upper + op1_alignment) * sizeof(double), 64); + init_in_buf = malloc((upper + op1_alignment) * sizeof(double)); + init_inout_buf = malloc((upper + op1_alignment) * sizeof(double)); + duration = (double *) malloc(max_shift * sizeof(double)); + inout_check_buf = malloc(upper * sizeof(double)); + rank = ompi_comm_rank(MPI_COMM_WORLD); (void) rank; size = ompi_comm_size(MPI_COMM_WORLD); (void) size; + print_header(max_shift); + for (uint32_t type_idx = 0; type_idx < strlen(type); type_idx++) { for (uint32_t op_idx = 0; do_ops[op_idx] >= 0; op_idx++) { op = array_of_ops[do_ops[op_idx]].name; @@ -318,39 +493,55 @@ int main(int argc, char **argv) + op1_alignment * sizeof(int8_t)), *inout_int8 = (int8_t *) ((char *) inout_buf + res_alignment * sizeof(int8_t)), - *inout_int8_for_check = (int8_t *) inout_check_buf; + *inout_int8_for_check = (int8_t *) inout_check_buf, + *init_inout_int8 = (int8_t *)init_inout_buf, + *init_in_int8 = (int8_t *)init_in_buf; for (i = 0; i < count; i++) { - in_int8[i] = 5; - inout_int8[i] = inout_int8_for_check[i] = -3; + init_in_int8[i] = 5; + init_inout_int8[i] = -3; } mpi_type = "MPI_INT8_T"; if (0 == strcmp(op, "sum")) { - MPI_OP_TEST(+, mpi_op, MPI_INT8_T, int8_t, in_int8, inout_int8, + MPI_OP_TEST(+, mpi_op, MPI_INT8_T, int8_t, + init_in_int8, in_int8, + init_inout_int8, inout_int8, inout_int8_for_check, count, PRId8); } if (0 == strcmp(op, "bor")) { - MPI_OP_TEST(|, mpi_op, MPI_INT8_T, int8_t, in_int8, inout_int8, + MPI_OP_TEST(|, mpi_op, MPI_INT8_T, int8_t, + init_in_int8, in_int8, + init_inout_int8, inout_int8, inout_int8_for_check, count, PRId8); } if (0 == strcmp(op, "bxor")) { - MPI_OP_TEST(^, mpi_op, MPI_INT8_T, int8_t, in_int8, inout_int8, + MPI_OP_TEST(^, mpi_op, MPI_INT8_T, int8_t, + init_in_int8, in_int8, + init_inout_int8, inout_int8, inout_int8_for_check, count, PRId8); } if (0 == strcmp(op, "prod")) { - MPI_OP_TEST(*, mpi_op, MPI_INT8_T, int8_t, in_int8, inout_int8, + MPI_OP_TEST(*, mpi_op, MPI_INT8_T, 
int8_t, + init_in_int8, in_int8, + init_inout_int8, inout_int8, inout_int8_for_check, count, PRId8); } if (0 == strcmp(op, "band")) { - MPI_OP_TEST(&, mpi_op, MPI_INT8_T, int8_t, in_int8, inout_int8, + MPI_OP_TEST(&, mpi_op, MPI_INT8_T, int8_t, + init_in_int8, in_int8, + init_inout_int8, inout_int8, inout_int8_for_check, count, PRId8); } if (0 == strcmp(op, "max")) { - MPI_OP_MINMAX_TEST(max, mpi_op, MPI_INT8_T, int8_t, in_int8, inout_int8, + MPI_OP_MINMAX_TEST(max, mpi_op, MPI_INT8_T, int8_t, + init_in_int8, in_int8, + init_inout_int8, inout_int8, inout_int8_for_check, count, PRId8); } if (0 == strcmp(op, "min")) { // intentionally reversed in and out - MPI_OP_MINMAX_TEST(min, mpi_op, MPI_INT8_T, int8_t, in_int8, inout_int8, + MPI_OP_MINMAX_TEST(min, mpi_op, MPI_INT8_T, int8_t, + init_in_int8, in_int8, + init_inout_int8, inout_int8, inout_int8_for_check, count, PRId8); } } @@ -359,40 +550,56 @@ int main(int argc, char **argv) + op1_alignment * sizeof(int16_t)), *inout_int16 = (int16_t *) ((char *) inout_buf + res_alignment * sizeof(int16_t)), - *inout_int16_for_check = (int16_t *) inout_check_buf; + *inout_int16_for_check = (int16_t *) inout_check_buf, + *init_inout_int16 = (int16_t *)init_inout_buf, + *init_in_int16 = (int16_t *)init_in_buf; for (i = 0; i < count; i++) { - in_int16[i] = 5; - inout_int16[i] = inout_int16_for_check[i] = -3; + init_in_int16[i] = 5; + init_inout_int16[i] = -3; } mpi_type = "MPI_INT16_T"; if (0 == strcmp(op, "sum")) { - MPI_OP_TEST(+, mpi_op, MPI_INT16_T, int16_t, in_int16, inout_int16, + MPI_OP_TEST(+, mpi_op, MPI_INT16_T, int16_t, + init_in_int16, in_int16, + init_inout_int16, inout_int16, inout_int16_for_check, count, PRId16); } if (0 == strcmp(op, "bor")) { - MPI_OP_TEST(|, mpi_op, MPI_INT16_T, int16_t, in_int16, inout_int16, + MPI_OP_TEST(|, mpi_op, MPI_INT16_T, int16_t, + init_in_int16, in_int16, + init_inout_int16, inout_int16, inout_int16_for_check, count, PRId16); } if (0 == strcmp(op, "bxor")) { - MPI_OP_TEST(^, mpi_op, 
MPI_INT16_T, int16_t, in_int16, inout_int16, + MPI_OP_TEST(^, mpi_op, MPI_INT16_T, int16_t, + init_in_int16, in_int16, + init_inout_int16, inout_int16, inout_int16_for_check, count, PRId16); } if (0 == strcmp(op, "prod")) { - MPI_OP_TEST(*, mpi_op, MPI_INT16_T, int16_t, in_int16, inout_int16, + MPI_OP_TEST(*, mpi_op, MPI_INT16_T, int16_t, + init_in_int16, in_int16, + init_inout_int16, inout_int16, inout_int16_for_check, count, PRId16); } if (0 == strcmp(op, "band")) { - MPI_OP_TEST(&, mpi_op, MPI_INT16_T, int16_t, in_int16, inout_int16, + MPI_OP_TEST(&, mpi_op, MPI_INT16_T, int16_t, + init_in_int16, in_int16, + init_inout_int16, inout_int16, inout_int16_for_check, count, PRId16); } if (0 == strcmp(op, "max")) { - MPI_OP_MINMAX_TEST(max, mpi_op, MPI_INT16_T, int16_t, in_int16, - inout_int16, inout_int16_for_check, count, PRId16); + MPI_OP_MINMAX_TEST(max, mpi_op, MPI_INT16_T, int16_t, + init_in_int16, in_int16, + init_inout_int16, inout_int16, + inout_int16_for_check, count, PRId16); } if (0 == strcmp(op, "min")) { // intentionally reversed in and out - MPI_OP_MINMAX_TEST(min, mpi_op, MPI_INT16_T, int16_t, in_int16, - inout_int16, inout_int16_for_check, count, PRId16); + MPI_OP_MINMAX_TEST(min, mpi_op, MPI_INT16_T, int16_t, + init_in_int16, in_int16, + init_inout_int16, inout_int16, + inout_int16_for_check, count, PRId16); } } if (32 == type_size) { @@ -400,40 +607,56 @@ int main(int argc, char **argv) + op1_alignment * sizeof(int32_t)), *inout_int32 = (int32_t *) ((char *) inout_buf + res_alignment * sizeof(int32_t)), - *inout_int32_for_check = (int32_t *) inout_check_buf; + *inout_int32_for_check = (int32_t *) inout_check_buf, + *init_inout_int32 = (int32_t *)init_inout_buf, + *init_in_int32 = (int32_t *)init_in_buf; for (i = 0; i < count; i++) { - in_int32[i] = 5; - inout_int32[i] = inout_int32_for_check[i] = 3; + init_in_int32[i] = 5; + init_inout_int32[i] = inout_int32_for_check[i] = 3; } mpi_type = "MPI_INT32_T"; if (0 == strcmp(op, "sum")) { - MPI_OP_TEST(+, 
mpi_op, MPI_INT32_T, int32_t, in_int32, inout_int32, + MPI_OP_TEST(+, mpi_op, MPI_INT32_T, int32_t, + init_in_int32, in_int32, + init_inout_int32, inout_int32, inout_int32_for_check, count, PRId32); } if (0 == strcmp(op, "bor")) { - MPI_OP_TEST(|, mpi_op, MPI_INT32_T, int32_t, in_int32, inout_int32, + MPI_OP_TEST(|, mpi_op, MPI_INT32_T, int32_t, + init_in_int32, in_int32, + init_inout_int32, inout_int32, inout_int32_for_check, count, PRId32); } if (0 == strcmp(op, "bxor")) { - MPI_OP_TEST(^, mpi_op, MPI_INT32_T, int32_t, in_int32, inout_int32, + MPI_OP_TEST(^, mpi_op, MPI_INT32_T, int32_t, + init_in_int32, in_int32, + init_inout_int32, inout_int32, inout_int32_for_check, count, PRId32); } if (0 == strcmp(op, "prod")) { - MPI_OP_TEST(*, mpi_op, MPI_INT32_T, int32_t, in_int32, inout_int32, + MPI_OP_TEST(*, mpi_op, MPI_INT32_T, int32_t, + init_in_int32, in_int32, + init_inout_int32, inout_int32, inout_int32_for_check, count, PRId32); } if (0 == strcmp(op, "band")) { - MPI_OP_TEST(&, mpi_op, MPI_INT32_T, int32_t, in_int32, inout_int32, + MPI_OP_TEST(&, mpi_op, MPI_INT32_T, int32_t, + init_in_int32, in_int32, + init_inout_int32, inout_int32, inout_int32_for_check, count, PRId32); } if (0 == strcmp(op, "max")) { - MPI_OP_MINMAX_TEST(max, mpi_op, MPI_INT32_T, int32_t, in_int32, - inout_int32, inout_int32_for_check, count, PRId32); + MPI_OP_MINMAX_TEST(max, mpi_op, MPI_INT32_T, int32_t, + init_in_int32, in_int32, + init_inout_int32, inout_int32, + inout_int32_for_check, count, PRId32); } if (0 == strcmp(op, "min")) { // intentionally reversed in and out - MPI_OP_MINMAX_TEST(min, mpi_op, MPI_INT32_T, int32_t, in_int32, - inout_int32, inout_int32_for_check, count, PRId32); + MPI_OP_MINMAX_TEST(min, mpi_op, MPI_INT32_T, int32_t, + init_in_int32, in_int32, + init_inout_int32, inout_int32, + inout_int32_for_check, count, PRId32); } } if (64 == type_size) { @@ -441,40 +664,56 @@ int main(int argc, char **argv) + op1_alignment * sizeof(int64_t)), *inout_int64 = (int64_t *) ((char 
*) inout_buf + res_alignment * sizeof(int64_t)), - *inout_int64_for_check = (int64_t *) inout_check_buf; + *inout_int64_for_check = (int64_t *) inout_check_buf, + *init_inout_int64 = (int64_t *)init_inout_buf, + *init_in_int64 = (int64_t *)init_in_buf; for (i = 0; i < count; i++) { - in_int64[i] = 5; - inout_int64[i] = inout_int64_for_check[i] = 3; + init_in_int64[i] = 5; + init_inout_int64[i] = 3; } mpi_type = "MPI_INT64_T"; if (0 == strcmp(op, "sum")) { - MPI_OP_TEST(+, mpi_op, MPI_INT64_T, int64_t, in_int64, inout_int64, + MPI_OP_TEST(+, mpi_op, MPI_INT64_T, int64_t, + init_in_int64, in_int64, + init_inout_int64, inout_int64, inout_int64_for_check, count, PRId64); } if (0 == strcmp(op, "bor")) { - MPI_OP_TEST(|, mpi_op, MPI_INT64_T, int64_t, in_int64, inout_int64, + MPI_OP_TEST(|, mpi_op, MPI_INT64_T, int64_t, + init_in_int64, in_int64, + init_inout_int64, inout_int64, inout_int64_for_check, count, PRId64); } if (0 == strcmp(op, "bxor")) { - MPI_OP_TEST(^, mpi_op, MPI_INT64_T, int64_t, in_int64, inout_int64, + MPI_OP_TEST(^, mpi_op, MPI_INT64_T, int64_t, + init_in_int64, in_int64, + init_inout_int64, inout_int64, inout_int64_for_check, count, PRId64); } if (0 == strcmp(op, "prod")) { - MPI_OP_TEST(*, mpi_op, MPI_INT64_T, int64_t, in_int64, inout_int64, + MPI_OP_TEST(*, mpi_op, MPI_INT64_T, int64_t, + init_in_int64, in_int64, + init_inout_int64, inout_int64, inout_int64_for_check, count, PRId64); } if (0 == strcmp(op, "band")) { - MPI_OP_TEST(&, mpi_op, MPI_INT64_T, int64_t, in_int64, inout_int64, + MPI_OP_TEST(&, mpi_op, MPI_INT64_T, int64_t, + init_in_int64, in_int64, + init_inout_int64, inout_int64, inout_int64_for_check, count, PRId64); } if (0 == strcmp(op, "max")) { - MPI_OP_MINMAX_TEST(max, mpi_op, MPI_INT64_T, int64_t, in_int64, - inout_int64, inout_int64_for_check, count, PRId64); + MPI_OP_MINMAX_TEST(max, mpi_op, MPI_INT64_T, int64_t, + init_in_int64, in_int64, + init_inout_int64, inout_int64, + inout_int64_for_check, count, PRId64); } if (0 == 
strcmp(op, "min")) { // intentionally reversed in and out - MPI_OP_MINMAX_TEST(min, mpi_op, MPI_INT64_T, int64_t, in_int64, - inout_int64, inout_int64_for_check, count, PRId64); + MPI_OP_MINMAX_TEST(min, mpi_op, MPI_INT64_T, int64_t, + init_in_int64, in_int64, + init_inout_int64, inout_int64, + inout_int64_for_check, count, PRId64); } } } @@ -485,40 +724,56 @@ int main(int argc, char **argv) + op1_alignment * sizeof(uint8_t)), *inout_uint8 = (uint8_t *) ((char *) inout_buf + res_alignment * sizeof(uint8_t)), - *inout_uint8_for_check = (uint8_t *) inout_check_buf; + *inout_uint8_for_check = (uint8_t *) inout_check_buf, + *init_inout_uint8 = (uint8_t *)init_inout_buf, + *init_in_uint8 = (uint8_t *)init_in_buf; for (i = 0; i < count; i++) { - in_uint8[i] = 5; - inout_uint8[i] = inout_uint8_for_check[i] = 2; + init_in_uint8[i] = 5; + init_inout_uint8[i] = 2; } mpi_type = "MPI_UINT8_T"; if (0 == strcmp(op, "sum")) { - MPI_OP_TEST(+, mpi_op, MPI_UINT8_T, uint8_t, in_uint8, inout_uint8, + MPI_OP_TEST(+, mpi_op, MPI_UINT8_T, uint8_t, + init_in_uint8, in_uint8, + init_inout_uint8, inout_uint8, inout_uint8_for_check, count, PRIu8); } if (0 == strcmp(op, "bor")) { - MPI_OP_TEST(|, mpi_op, MPI_UINT8_T, uint8_t, in_uint8, inout_uint8, + MPI_OP_TEST(|, mpi_op, MPI_UINT8_T, uint8_t, + init_in_uint8, in_uint8, + init_inout_uint8, inout_uint8, inout_uint8_for_check, count, PRIu8); } if (0 == strcmp(op, "bxor")) { - MPI_OP_TEST(^, mpi_op, MPI_UINT8_T, uint8_t, in_uint8, inout_uint8, + MPI_OP_TEST(^, mpi_op, MPI_UINT8_T, uint8_t, + init_in_uint8, in_uint8, + init_inout_uint8, inout_uint8, inout_uint8_for_check, count, PRIu8); } if (0 == strcmp(op, "prod")) { - MPI_OP_TEST(*, mpi_op, MPI_UINT8_T, uint8_t, in_uint8, inout_uint8, + MPI_OP_TEST(*, mpi_op, MPI_UINT8_T, uint8_t, + init_in_uint8, in_uint8, + init_inout_uint8, inout_uint8, inout_uint8_for_check, count, PRIu8); } if (0 == strcmp(op, "band")) { - MPI_OP_TEST(&, mpi_op, MPI_UINT8_T, uint8_t, in_uint8, inout_uint8, + 
MPI_OP_TEST(&, mpi_op, MPI_UINT8_T, uint8_t, + init_in_uint8, in_uint8, + init_inout_uint8, inout_uint8, inout_uint8_for_check, count, PRIu8); } if (0 == strcmp(op, "max")) { - MPI_OP_MINMAX_TEST(max, mpi_op, MPI_UINT8_T, uint8_t, in_uint8, - inout_uint8, inout_uint8_for_check, count, PRIu8); + MPI_OP_MINMAX_TEST(max, mpi_op, MPI_UINT8_T, uint8_t, + init_in_uint8, in_uint8, + init_inout_uint8, inout_uint8, + inout_uint8_for_check, count, PRIu8); } if (0 == strcmp(op, "min")) { // intentionally reversed in and out - MPI_OP_MINMAX_TEST(min, mpi_op, MPI_UINT8_T, uint8_t, in_uint8, - inout_uint8, inout_uint8_for_check, count, PRIu8); + MPI_OP_MINMAX_TEST(min, mpi_op, MPI_UINT8_T, uint8_t, + init_in_uint8, in_uint8, + init_inout_uint8, inout_uint8, + inout_uint8_for_check, count, PRIu8); } } if (16 == type_size) { @@ -526,40 +781,56 @@ int main(int argc, char **argv) + op1_alignment * sizeof(uint16_t)), *inout_uint16 = (uint16_t *) ((char *) inout_buf + res_alignment * sizeof(uint16_t)), - *inout_uint16_for_check = (uint16_t *) inout_check_buf; + *inout_uint16_for_check = (uint16_t *) inout_check_buf, + *init_inout_uint16 = (uint16_t *)init_inout_buf, + *init_in_uint16 = (uint16_t *)init_in_buf; for (i = 0; i < count; i++) { - in_uint16[i] = 5; - inout_uint16[i] = inout_uint16_for_check[i] = 1234; + init_in_uint16[i] = 5; + init_inout_uint16[i] = 1234; } mpi_type = "MPI_UINT16_T"; if (0 == strcmp(op, "sum")) { - MPI_OP_TEST(+, mpi_op, MPI_UINT16_T, uint16_t, in_uint16, inout_uint16, + MPI_OP_TEST(+, mpi_op, MPI_UINT16_T, uint16_t, + init_in_uint16, in_uint16, + init_inout_uint16, inout_uint16, inout_uint16_for_check, count, PRIu16); } if (0 == strcmp(op, "bor")) { - MPI_OP_TEST(|, mpi_op, MPI_UINT16_T, uint16_t, in_uint16, inout_uint16, + MPI_OP_TEST(|, mpi_op, MPI_UINT16_T, uint16_t, + init_in_uint16, in_uint16, + init_inout_uint16, inout_uint16, inout_uint16_for_check, count, PRIu16); } if (0 == strcmp(op, "bxor")) { - MPI_OP_TEST(^, mpi_op, MPI_UINT16_T, uint16_t, 
in_uint16, inout_uint16, + MPI_OP_TEST(^, mpi_op, MPI_UINT16_T, uint16_t, + init_in_uint16, in_uint16, + init_inout_uint16, inout_uint16, inout_uint16_for_check, count, PRIu16); } if (0 == strcmp(op, "prod")) { - MPI_OP_TEST(*, mpi_op, MPI_UINT16_T, uint16_t, in_uint16, inout_uint16, + MPI_OP_TEST(*, mpi_op, MPI_UINT16_T, uint16_t, + init_in_uint16, in_uint16, + init_inout_uint16, inout_uint16, inout_uint16_for_check, count, PRIu16); } if (0 == strcmp(op, "band")) { - MPI_OP_TEST(&, mpi_op, MPI_UINT16_T, uint16_t, in_uint16, inout_uint16, + MPI_OP_TEST(&, mpi_op, MPI_UINT16_T, uint16_t, + init_in_uint16, in_uint16, + init_inout_uint16, inout_uint16, inout_uint16_for_check, count, PRIu16); } if (0 == strcmp(op, "max")) { - MPI_OP_MINMAX_TEST(max, mpi_op, MPI_UINT16_T, uint16_t, in_uint16, - inout_uint16, inout_uint16_for_check, count, PRIu16); + MPI_OP_MINMAX_TEST(max, mpi_op, MPI_UINT16_T, uint16_t, + init_in_uint16, in_uint16, + init_inout_uint16, inout_uint16, + inout_uint16_for_check, count, PRIu16); } if (0 == strcmp(op, "min")) { // intentionally reversed in and out - MPI_OP_MINMAX_TEST(min, mpi_op, MPI_UINT16_T, uint16_t, in_uint16, - inout_uint16, inout_uint16_for_check, count, PRIu16); + MPI_OP_MINMAX_TEST(min, mpi_op, MPI_UINT16_T, uint16_t, + init_in_uint16, in_uint16, + init_inout_uint16, inout_uint16, + inout_uint16_for_check, count, PRIu16); } } if (32 == type_size) { @@ -567,40 +838,56 @@ int main(int argc, char **argv) + op1_alignment * sizeof(uint32_t)), *inout_uint32 = (uint32_t *) ((char *) inout_buf + res_alignment * sizeof(uint32_t)), - *inout_uint32_for_check = (uint32_t *) inout_check_buf; + *inout_uint32_for_check = (uint32_t *) inout_check_buf, + *init_inout_uint32 = (uint32_t *)init_inout_buf, + *init_in_uint32 = (uint32_t *)init_in_buf; for (i = 0; i < count; i++) { - in_uint32[i] = 5; - inout_uint32[i] = inout_uint32_for_check[i] = 3; + init_in_uint32[i] = 5; + init_inout_uint32[i] = 3; } mpi_type = "MPI_UINT32_T"; if (0 == strcmp(op, 
"sum")) { - MPI_OP_TEST(+, mpi_op, MPI_UINT32_T, uint32_t, in_uint32, inout_uint32, + MPI_OP_TEST(+, mpi_op, MPI_UINT32_T, uint32_t, + init_in_uint32, in_uint32, + init_inout_uint32, inout_uint32, inout_uint32_for_check, count, PRIu32); } if (0 == strcmp(op, "bor")) { - MPI_OP_TEST(|, mpi_op, MPI_UINT32_T, uint32_t, in_uint32, inout_uint32, + MPI_OP_TEST(|, mpi_op, MPI_UINT32_T, uint32_t, + init_in_uint32, in_uint32, + init_inout_uint32, inout_uint32, inout_uint32_for_check, count, PRIu32); } if (0 == strcmp(op, "bxor")) { - MPI_OP_TEST(^, mpi_op, MPI_UINT32_T, uint32_t, in_uint32, inout_uint32, + MPI_OP_TEST(^, mpi_op, MPI_UINT32_T, uint32_t, + init_in_uint32, in_uint32, + init_inout_uint32, inout_uint32, inout_uint32_for_check, count, PRIu32); } if (0 == strcmp(op, "prod")) { - MPI_OP_TEST(*, mpi_op, MPI_UINT32_T, uint32_t, in_uint32, inout_uint32, + MPI_OP_TEST(*, mpi_op, MPI_UINT32_T, uint32_t, + init_in_uint32, in_uint32, + init_inout_uint32, inout_uint32, inout_uint32_for_check, count, PRIu32); } if (0 == strcmp(op, "band")) { - MPI_OP_TEST(&, mpi_op, MPI_UINT32_T, uint32_t, in_uint32, inout_uint32, + MPI_OP_TEST(&, mpi_op, MPI_UINT32_T, uint32_t, + init_in_uint32, in_uint32, + init_inout_uint32, inout_uint32, inout_uint32_for_check, count, PRIu32); } if (0 == strcmp(op, "max")) { - MPI_OP_MINMAX_TEST(max, mpi_op, MPI_UINT32_T, uint32_t, in_uint32, - inout_uint32, inout_uint32_for_check, count, PRIu32); + MPI_OP_MINMAX_TEST(max, mpi_op, MPI_UINT32_T, uint32_t, + init_in_uint32, in_uint32, + init_inout_uint32, inout_uint32, + inout_uint32_for_check, count, PRIu32); } if (0 == strcmp(op, "min")) { // intentionally reversed in and out - MPI_OP_MINMAX_TEST(min, mpi_op, MPI_UINT32_T, uint32_t, in_uint32, - inout_uint32, inout_uint32_for_check, count, PRIu32); + MPI_OP_MINMAX_TEST(min, mpi_op, MPI_UINT32_T, uint32_t, + init_in_uint32, in_uint32, + init_inout_uint32, inout_uint32, + inout_uint32_for_check, count, PRIu32); } } if (64 == type_size) { @@ -608,40 
+895,56 @@ int main(int argc, char **argv) + op1_alignment * sizeof(uint64_t)), *inout_uint64 = (uint64_t *) ((char *) inout_buf + res_alignment * sizeof(uint64_t)), - *inout_uint64_for_check = (uint64_t *) inout_check_buf; + *inout_uint64_for_check = (uint64_t *) inout_check_buf, + *init_inout_uint64 = (uint64_t *)init_inout_buf, + *init_in_uint64 = (uint64_t *)init_in_buf; for (i = 0; i < count; i++) { - in_uint64[i] = 5; - inout_uint64[i] = inout_uint64_for_check[i] = 32433; + init_in_uint64[i] = 5; + init_inout_uint64[i] = 32433; } mpi_type = "MPI_UINT64_T"; if (0 == strcmp(op, "sum")) { - MPI_OP_TEST(+, mpi_op, MPI_UINT64_T, uint64_t, in_uint64, inout_uint64, + MPI_OP_TEST(+, mpi_op, MPI_UINT64_T, uint64_t, + init_in_uint64, in_uint64, + init_inout_uint64, inout_uint64, inout_uint64_for_check, count, PRIu64); } if (0 == strcmp(op, "bor")) { - MPI_OP_TEST(|, mpi_op, MPI_UINT64_T, uint64_t, in_uint64, inout_uint64, + MPI_OP_TEST(|, mpi_op, MPI_UINT64_T, uint64_t, + init_in_uint64, in_uint64, + init_inout_uint64, inout_uint64, inout_uint64_for_check, count, PRIu64); } if (0 == strcmp(op, "bxor")) { - MPI_OP_TEST(^, mpi_op, MPI_UINT64_T, uint64_t, in_uint64, inout_uint64, + MPI_OP_TEST(^, mpi_op, MPI_UINT64_T, uint64_t, + init_in_uint64, in_uint64, + init_inout_uint64, inout_uint64, inout_uint64_for_check, count, PRIu64); } if (0 == strcmp(op, "prod")) { - MPI_OP_TEST(*, mpi_op, MPI_UINT64_T, uint64_t, in_uint64, inout_uint64, + MPI_OP_TEST(*, mpi_op, MPI_UINT64_T, uint64_t, + init_in_uint64, in_uint64, + init_inout_uint64, inout_uint64, inout_uint64_for_check, count, PRIu64); } if (0 == strcmp(op, "band")) { - MPI_OP_TEST(&, mpi_op, MPI_UINT64_T, uint64_t, in_uint64, inout_uint64, + MPI_OP_TEST(&, mpi_op, MPI_UINT64_T, uint64_t, + init_in_uint64, in_uint64, + init_inout_uint64, inout_uint64, inout_uint64_for_check, count, PRIu64); } if (0 == strcmp(op, "max")) { - MPI_OP_MINMAX_TEST(max, mpi_op, MPI_UINT64_T, uint64_t, in_uint64, - inout_uint64, 
inout_uint64_for_check, count, PRIu64); + MPI_OP_MINMAX_TEST(max, mpi_op, MPI_UINT64_T, uint64_t, + init_in_uint64, in_uint64, + init_inout_uint64, inout_uint64, + inout_uint64_for_check, count, PRIu64); } if (0 == strcmp(op, "min")) { - MPI_OP_MINMAX_TEST(min, mpi_op, MPI_UINT64_T, uint64_t, in_uint64, - inout_uint64, inout_uint64_for_check, count, PRIu64); + MPI_OP_MINMAX_TEST(min, mpi_op, MPI_UINT64_T, uint64_t, + init_in_uint64, in_uint64, + init_inout_uint64, inout_uint64, + inout_uint64_for_check, count, PRIu64); } } } @@ -650,27 +953,37 @@ int main(int argc, char **argv) float *in_float = (float *) ((char *) in_buf + op1_alignment * sizeof(float)), *inout_float = (float *) ((char *) inout_buf + res_alignment * sizeof(float)), - *inout_float_for_check = (float *) inout_check_buf; + *inout_float_for_check = (float *) inout_check_buf, + *init_inout_float = (float *)init_inout_buf, + *init_in_float = (float *)init_in_buf; for (i = 0; i < count; i++) { - in_float[i] = 1000.0 + 1; - inout_float[i] = inout_float_for_check[i] = 100.0 + 2; + init_in_float[i] = 1000.0 + 1; + init_inout_float[i] = 100.0 + 2; } mpi_type = "MPI_FLOAT"; if (0 == strcmp(op, "sum")) { - MPI_OP_TEST(+, mpi_op, MPI_FLOAT, float, in_float, inout_float, + MPI_OP_TEST(+, mpi_op, MPI_FLOAT, float, + init_in_float, in_float, + init_inout_float, inout_float, inout_float_for_check, count, "f"); } if (0 == strcmp(op, "prod")) { - MPI_OP_TEST(*, mpi_op, MPI_FLOAT, float, in_float, inout_float, + MPI_OP_TEST(*, mpi_op, MPI_FLOAT, float, + init_in_float, in_float, + init_inout_float, inout_float, inout_float_for_check, count, "f"); } if (0 == strcmp(op, "max")) { - MPI_OP_MINMAX_TEST(max, mpi_op, MPI_FLOAT, float, in_float, inout_float, + MPI_OP_MINMAX_TEST(max, mpi_op, MPI_FLOAT, float, + init_in_float, in_float, + init_inout_float, inout_float, inout_float_for_check, count, "f"); } if (0 == strcmp(op, "min")) { - MPI_OP_MINMAX_TEST(min, mpi_op, MPI_FLOAT, float, in_float, inout_float, + 
MPI_OP_MINMAX_TEST(min, mpi_op, MPI_FLOAT, float, + init_in_float, in_float, + init_inout_float, inout_float, inout_float_for_check, count, "f"); } } @@ -680,27 +993,37 @@ int main(int argc, char **argv) + op1_alignment * sizeof(double)), *inout_double = (double *) ((char *) inout_buf + res_alignment * sizeof(double)), - *inout_double_for_check = (double *) inout_check_buf; + *inout_double_for_check = (double *) inout_check_buf, + *init_inout_double = (double *)init_inout_buf, + *init_in_double = (double *)init_in_buf; for (i = 0; i < count; i++) { - in_double[i] = 10.0 + 1; - inout_double[i] = inout_double_for_check[i] = 1.0 + 2; + init_in_double[i] = 10.0 + 1; + init_inout_double[i] = 1.0 + 2; } mpi_type = "MPI_DOUBLE"; if (0 == strcmp(op, "sum")) { - MPI_OP_TEST(+, mpi_op, MPI_DOUBLE, double, in_double, inout_double, - inout_double_for_check, count, "g"); + MPI_OP_TEST(+, mpi_op, MPI_DOUBLE, double, + init_in_double, in_double, + init_inout_double, inout_double, + inout_double_for_check, count, "f"); } if (0 == strcmp(op, "prod")) { - MPI_OP_TEST(*, mpi_op, MPI_DOUBLE, double, in_double, inout_double, + MPI_OP_TEST(*, mpi_op, MPI_DOUBLE, double, + init_in_double, in_double, + init_inout_double, inout_double, inout_double_for_check, count, "f"); } if (0 == strcmp(op, "max")) { - MPI_OP_MINMAX_TEST(max, mpi_op, MPI_DOUBLE, double, in_double, inout_double, + MPI_OP_MINMAX_TEST(max, mpi_op, MPI_DOUBLE, double, + init_in_double, in_double, + init_inout_double, inout_double, inout_double_for_check, count, "f"); } if (0 == strcmp(op, "min")) { - MPI_OP_MINMAX_TEST(min, mpi_op, MPI_DOUBLE, double, in_double, inout_double, + MPI_OP_MINMAX_TEST(min, mpi_op, MPI_DOUBLE, double, + init_in_double, in_double, + init_inout_double, inout_double, inout_double_for_check, count, "f"); } } @@ -713,11 +1036,19 @@ int main(int argc, char **argv) printf("\n"); } } - ompi_mpi_finalize(); - free(in_buf); - free(inout_buf); - free(inout_check_buf); + /* clean up allocator */ + 
allocator->free(in_buf); + allocator->free(inout_buf); + allocator->free(inout_check_buf); + allocator->fini(); + + if (allocator->flags & ALLOCATOR_DISCRETE) { + free(init_in_buf); + free(init_inout_buf); + } + + ompi_mpi_finalize(); return (0 == total_errors) ? 0 : -1; }