diff --git a/dipu/scripts/autogen_diopi_wrapper/autogen_diopi_wrapper.py b/dipu/scripts/autogen_diopi_wrapper/autogen_diopi_wrapper.py index 84c4aa2ac..a93297ca5 100644 --- a/dipu/scripts/autogen_diopi_wrapper/autogen_diopi_wrapper.py +++ b/dipu/scripts/autogen_diopi_wrapper/autogen_diopi_wrapper.py @@ -905,8 +905,7 @@ def functions_code_gen(fun_config): ) fbody += custom_autograd_function_code fun_name = wrapper_fun_name - - if fun_config.get("autocompare", False) in [True, "True"] and fun_config.get( + if fun_config.get("autocompare") not in ["disable"] and fun_config.get( "register_op", True ) in [True, "True"]: auto_compare_fun_name = fun_name + "_autocompare" @@ -929,17 +928,16 @@ def functions_code_gen(fun_config): ).replace(raw_fun_name, fun_name) ], transform_result_to_cpu_code=[], - result_compare_code=[ - create_result_compare_code(fun_config) - + ( - "\nreturn result_device;\n" - if len(get_function_return_param_from_schema(fun_config["schema"])) - > 0 - else "" - ) - ], + result_compare_code=[create_result_compare_code(fun_config)], ) fbody += autocompare_code + last_brace_pos = fbody.rfind("}") + fbody = fbody[:last_brace_pos] + fbody += ( + "\n\treturn result_device;\n}" + if len(get_function_return_param_from_schema(fun_config["schema"])) > 0 + else "\n\t}" + ) fun_name = auto_compare_fun_name if fun_config.get("custom_fallback", False) in ["False", False]: @@ -1039,12 +1037,7 @@ def parse_args(): type=boolean_string, help="whether generate code that prints op args", ) - parser.add_argument( - "--autocompare", - default=False, - type=boolean_string, - help="whether generate code that compare device calculation results with cpu calculation results", - ) + parser.add_argument( "--fun_config_dict", type=json.loads, diff --git a/dipu/scripts/autogen_diopi_wrapper/autogen_wrapped_code.sh b/dipu/scripts/autogen_diopi_wrapper/autogen_wrapped_code.sh index 3d6e0dd18..fd6e01b11 100755 --- a/dipu/scripts/autogen_diopi_wrapper/autogen_wrapped_code.sh +++ 
b/dipu/scripts/autogen_diopi_wrapper/autogen_wrapped_code.sh @@ -5,17 +5,16 @@ DIPU_DIR=$(readlink -f $(dirname $(readlink -f "$0"))/../..) AUTOGEN_DIOPI_WRAPPER=$DIPU_DIR/scripts/autogen_diopi_wrapper -USE_AUTOCOMPARE=${1:-OFF} -UsedVendor=${2:-cuda} -Torch_VERSION=${3:-2.1.0} -GENERATED_KERNELS_SCRIPT=${4:-$AUTOGEN_DIOPI_WRAPPER/autogen_diopi_wrapper.py} -GENERATED_KERNELS_CONFIG=${5:-$AUTOGEN_DIOPI_WRAPPER/diopi_functions.yaml} -GENERATED_KERNELS=${6:-$DIPU_DIR/torch_dipu/csrc_dipu/aten/ops/AutoGenedKernels.cpp} +UsedVendor=${1:-cuda} +Torch_VERSION=${2:-2.1.0} +GENERATED_KERNELS_SCRIPT=${3:-$AUTOGEN_DIOPI_WRAPPER/autogen_diopi_wrapper.py} +GENERATED_KERNELS_CONFIG=${4:-$AUTOGEN_DIOPI_WRAPPER/diopi_functions.yaml} +GENERATED_KERNELS=${5:-$DIPU_DIR/torch_dipu/csrc_dipu/aten/ops/AutoGenedKernels.cpp} GENERATED_KERNELS_VENDOR=${DIPU_DIR}/third_party/DIOPI/impl/${UsedVendor}/convert_config.yaml PYTHON_CMD="python3 ${GENERATED_KERNELS_SCRIPT} --out=${GENERATED_KERNELS} --config=${GENERATED_KERNELS_CONFIG} \ - --autocompare=${USE_AUTOCOMPARE} --print_op_arg=True --use_diopi_adapter=False --print_func_call_info=True \ + --print_op_arg=True --use_diopi_adapter=False --print_func_call_info=True \ --fun_config_dict='{\"current_device\":\"${UsedVendor}\",\"current_torch_ver\":\"${Torch_VERSION}\"}'" if [ -f "$GENERATED_KERNELS_VENDOR" ]; then diff --git a/dipu/scripts/autogen_diopi_wrapper/diopi_wrapper_template.py b/dipu/scripts/autogen_diopi_wrapper/diopi_wrapper_template.py index ba723da1b..a073a2db6 100644 --- a/dipu/scripts/autogen_diopi_wrapper/diopi_wrapper_template.py +++ b/dipu/scripts/autogen_diopi_wrapper/diopi_wrapper_template.py @@ -171,12 +171,14 @@ class $autograd_function_name : public torch::autograd::Function<$autograd_funct std::cout << std::endl << __FUNCTION__ << std::endl; $transform_input_to_cpu_code - $execute_op_on_cpu_code - $execute_op_on_device_code + + if (useAutoCompare()) { + $execute_op_on_cpu_code - $transform_result_to_cpu_code + 
$transform_result_to_cpu_code - $result_compare_code + $result_compare_code + } } """ diff --git a/dipu/torch_dipu/csrc_dipu/CMakeLists.txt b/dipu/torch_dipu/csrc_dipu/CMakeLists.txt index 20bb442fe..7b3ebaa18 100644 --- a/dipu/torch_dipu/csrc_dipu/CMakeLists.txt +++ b/dipu/torch_dipu/csrc_dipu/CMakeLists.txt @@ -1,5 +1,4 @@ #[[ Dependencies ]] -option(USE_AUTOCOMPARE "whether to use USE_AUTOCOMPARE" OFF) # Import Python3::Python, Python3_EXECUTABLE # Also see https://cmake.org/cmake/help/latest/module/FindPython3.html @@ -58,7 +57,7 @@ endif() add_custom_command( OUTPUT "${GENERATED_KERNELS}" - COMMAND bash -c "${AUTOGEN_CODE_SH} ${USE_AUTOCOMPARE} ${UsedVendor} ${Torch_VERSION} ${GENERATED_KERNELS_SCRIPT} ${GENERATED_KERNELS_CONFIG} ${GENERATED_KERNELS}" + COMMAND bash -c "${AUTOGEN_CODE_SH} ${UsedVendor} ${Torch_VERSION} ${GENERATED_KERNELS_SCRIPT} ${GENERATED_KERNELS_CONFIG} ${GENERATED_KERNELS}" COMMENT "Generating ${GENERATED_KERNELS}$<$<BOOL:${GENERATED_KERNELS_VENDOR}>: with ${GENERATED_KERNELS_VENDOR}>" DEPENDS "${GENERATED_KERNELS_SCRIPT}" diff --git a/dipu/torch_dipu/csrc_dipu/aten/ops/OpUtils.hpp b/dipu/torch_dipu/csrc_dipu/aten/ops/OpUtils.hpp index dc9d9fcb3..80b7bfcbe 100644 --- a/dipu/torch_dipu/csrc_dipu/aten/ops/OpUtils.hpp +++ b/dipu/torch_dipu/csrc_dipu/aten/ops/OpUtils.hpp @@ -85,6 +85,28 @@ inline int dumpOpArgLevel() { return level; } +inline bool useAutoCompare() { + static const char* autocomparePtr = std::getenv("USE_AUTOCOMPARE"); + if (autocomparePtr == nullptr) { + return false; + } + + std::string autocompareString(autocomparePtr); + for (char& c : autocompareString) { + c = static_cast<char>(std::tolower(static_cast<unsigned char>(c))); + } + + if (autocompareString == "on") { + return true; + } + if (autocompareString == "off") { + return false; + } + + std::cerr << "Error: USE_AUTOCOMPARE can only be set to 'ON' or 'OFF'.\n"; + return false; +} + template <typename T> std::string dumpArg(const T& t) { std::stringstream stream;