@@ -1458,11 +1458,6 @@ static void ggml_backend_cann_free(ggml_backend_t backend) {
1458
1458
ACL_CHECK (aclrtSynchronizeDevice ());
1459
1459
ACL_CHECK (aclrtResetDevice (cann_ctx->device ));
1460
1460
1461
- // finalize when last backend freed.
1462
- if (cann_ctx->device == ggml_backend_cann_get_device_count () - 1 ) {
1463
- ACL_CHECK (aclFinalize ());
1464
- }
1465
-
1466
1461
delete cann_ctx;
1467
1462
delete backend;
1468
1463
}
@@ -1688,11 +1683,14 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
1688
1683
}
1689
1684
case GGML_OP_MUL_MAT: {
1690
1685
switch (op->src [0 ]->type ) {
1691
- case GGML_TYPE_Q8_0:
1692
1686
case GGML_TYPE_F16:
1693
1687
case GGML_TYPE_F32:
1694
- case GGML_TYPE_Q4_0:
1695
1688
return true ;
1689
+ case GGML_TYPE_Q8_0:
1690
+ case GGML_TYPE_Q4_0:
1691
+ // only support contiguous for quantized types.
1692
+ return ggml_is_contiguous (op->src [0 ]) &&
1693
+ ggml_is_contiguous (op->src [1 ]);
1696
1694
default :
1697
1695
return false ;
1698
1696
}
@@ -1738,13 +1736,14 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
1738
1736
}
1739
1737
case GGML_OP_ROPE: {
1740
1738
// TODO: with ops-test v == 1
1741
- float * ext_factor = (float *)((int32_t *)op->op_params + 7 );
1739
+ float ext_factor = 0 .0f ;
1740
+ memcpy (&ext_factor, (const float *) op->op_params + 7 , sizeof (float ));
1742
1741
// TODO: n_dims <= ne0
1743
1742
if (op->src [0 ]->ne [0 ] != op->op_params [1 ]) {
1744
1743
return false ;
1745
1744
}
1746
1745
// TODO: ext_factor != 0
1747
- if (* ext_factor != 0 ) {
1746
+ if (ext_factor != 0 ) {
1748
1747
return false ;
1749
1748
}
1750
1749
@@ -1766,6 +1765,16 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
1766
1765
}
1767
1766
return true ;
1768
1767
}
1768
+ case GGML_OP_POOL_2D: {
1769
+ const int32_t * opts = (const int32_t *) op->op_params ;
1770
+ const int k0 = opts[1 ];
1771
+ const int k1 = opts[2 ];
1772
+ const int p0 = opts[5 ];
1773
+ const int p1 = opts[6 ];
1774
+ // value of paddingH should be at most half of kernelH
1775
+ // value of paddingW should be at most half of kernelW
1776
+ return (p0 <= (k0 / 2 )) && (p1 <= (k1 / 2 ));
1777
+ }
1769
1778
case GGML_OP_DUP:
1770
1779
case GGML_OP_IM2COL:
1771
1780
case GGML_OP_CONCAT:
@@ -1785,7 +1794,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
1785
1794
case GGML_OP_CLAMP:
1786
1795
case GGML_OP_DIAG_MASK_INF:
1787
1796
case GGML_OP_SOFT_MAX:
1788
- case GGML_OP_POOL_2D:
1789
1797
case GGML_OP_SUM_ROWS:
1790
1798
case GGML_OP_ARGSORT:
1791
1799
case GGML_OP_ACC:
0 commit comments