Implement Hinge Embedding Loss #3623

Open · wants to merge 9 commits into base: develop
@littlecutebird littlecutebird commented Mar 14, 2025

  • Add a HingeEmbeddingLoss API (guarded by MIOPEN_BETA_API), along with solver, kernel, driver, and gtest coverage for fp32, fp16 and bfp16
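For reference, a minimal NumPy sketch of the elementwise definition this op implements (matching `torch.nn.HingeEmbeddingLoss`: l_n = x_n when y_n = 1, and max(0, margin − x_n) when y_n = −1); this is an illustrative sketch, not the MIOpen kernel itself:

```python
import numpy as np

def hinge_embedding_loss(x, y, margin=1.0, reduction="none"):
    # Elementwise: l = x where y == 1, max(0, margin - x) where y == -1.
    # margin defaults to 1.0, as in PyTorch; the benchmarks above use 0.4.
    loss = np.where(y == 1, x, np.maximum(0.0, margin - x))
    if reduction == "sum":
        return loss.sum()
    if reduction == "mean":
        return loss.mean()
    return loss
```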

  • Performance speedup compared to ROCm/pytorch:

type    Forward speedup
float32 2.24×

(fp16 is omitted because running this op in fp16 causes overflow/underflow in the output; bfp16 is omitted because PyTorch on ROCm does not support HingeEmbeddingLoss in bfp16.)
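A small standalone illustration of the fp16 overflow issue mentioned above (a hypothetical repro, not MIOpen code): once a running sum exceeds fp16's maximum representable value (~65504), it overflows to infinity, which is why the sum/mean reductions cannot be meaningfully benchmarked in fp16:

```python
import numpy as np

# Summing 10,000 values of 10.0 in fp16: the true sum (100000) exceeds
# fp16's max (~65504), so the fp16 accumulation overflows to inf.
total = np.full(10_000, 10.0, dtype=np.float16).sum()
print(total)
```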

fp32
Size Contiguous Reduction Margin Direction ROCm MIOpen Improvement
[10 100 100 100] TRUE none 0.4 fwd 700242 226738 3.088331025
[10 100 100 100] TRUE none 0.4 bwd 463201 301511 1.536265675
[10 100 100 100] TRUE sum 0.4 fwd 756242 297334 2.543409095
[10 100 100 100] TRUE sum 0.4 bwd 458242 231413 1.98019126
[10 100 100 100] TRUE mean 0.4 fwd 756322 311503 2.427976617
[10 100 100 100] TRUE mean 0.4 bwd 519442 260587 1.993353467
[32 32 16 64 32] TRUE none 0.4 fwd 2262166 750561 3.013966886
[32 32 16 64 32] TRUE none 0.4 bwd 1457844 999503 1.458568909
[32 32 16 64 32] TRUE sum 0.4 fwd 2373046 948178 2.502743156
[32 32 16 64 32] TRUE sum 0.4 bwd 1435444 764250 1.878238796
[32 32 16 64 32] TRUE mean 0.4 fwd 2371846 995254 2.383156461
[32 32 16 64 32] TRUE mean 0.4 bwd 1637844 861903 1.900264879
[128 256 8 256] TRUE none 0.4 fwd 4502972 1497850 3.00629035
[128 256 8 256] TRUE none 0.4 bwd 2887608 1995250 1.447241198
[128 256 8 256] TRUE sum 0.4 fwd 4708332 1874560 2.511699812
[128 256 8 256] TRUE sum 0.4 bwd 2826648 1526630 1.851560627
[128 256 8 256] TRUE mean 0.4 fwd 4702252 1969170 2.387936034
[128 256 8 256] TRUE mean 0.4 bwd 3228249 1722420 1.874251925
[20 50 100 100] TRUE none 0.4 fwd 698082 226542 3.081468337
[20 50 100 100] TRUE none 0.4 bwd 464722 302063 1.538493626
[20 50 100 100] TRUE sum 0.4 fwd 753842 296534 2.542177288
[20 50 100 100] TRUE sum 0.4 bwd 455762 231058 1.972500411
[20 50 100 100] TRUE mean 0.4 fwd 752082 311609 2.413543896
[20 50 100 100] TRUE mean 0.4 bwd 516721 260427 1.98412991
[2048 3096] TRUE none 0.4 fwd 452961 145440 3.114418317
[2048 3096] TRUE none 0.4 bwd 308241 193476 1.593174347
[2048 3096] TRUE sum 0.4 fwd 494082 197494 2.501757015
[2048 3096] TRUE sum 0.4 bwd 307200 149120 2.060085837
[2048 3096] TRUE mean 0.4 fwd 496162 209476 2.368586377
[2048 3096] TRUE mean 0.4 bwd 350241 167111 2.095858441
[2048 4096] TRUE none 0.4 fwd 583281 190596 3.060300321
[2048 4096] TRUE none 0.4 bwd 391281 254329 1.538483618
[2048 4096] TRUE sum 0.4 fwd 634882 252676 2.512632779
[2048 4096] TRUE sum 0.4 bwd 385361 194507 1.981219185
[2048 4096] TRUE mean 0.4 fwd 632002 265120 2.383833736
[2048 4096] TRUE mean 0.4 bwd 438081 219751 1.99353359
[256 512 512] TRUE none 0.4 fwd 4498412 1498030 3.002885122
[256 512 512] TRUE none 0.4 bwd 2886328 1998040 1.444579688
[256 512 512] TRUE sum 0.4 fwd 4703132 1876320 2.50657244
[256 512 512] TRUE sum 0.4 bwd 2828648 1525570 1.854158118
[256 512 512] TRUE mean 0.4 fwd 4709052 1970400 2.389896468
[256 512 512] TRUE mean 0.4 bwd 3223449 1721140 1.872856944
[25 1000 1000] TRUE none 0.4 fwd 1691124 560765 3.015744563
[25 1000 1000] TRUE none 0.4 bwd 1089603 747379 1.457898871
[25 1000 1000] TRUE sum 0.4 fwd 1779364 713956 2.492260027
[25 1000 1000] TRUE sum 0.4 bwd 1071043 570561 1.877175271
[25 1000 1000] TRUE mean 0.4 fwd 1776005 750561 2.366236722
[25 1000 1000] TRUE mean 0.4 bwd 1225283 644676 1.900618295
[2525 3333] TRUE none 0.4 fwd 586642 191591 3.061949674
[2525 3333] TRUE none 0.4 bwd 395041 255467 1.546348452
[2525 3333] TRUE sum 0.4 fwd 639682 254417 2.514305255
[2525 3333] TRUE sum 0.4 bwd 389281 195840 1.987750204
[2525 3333] TRUE mean 0.4 fwd 637602 268391 2.375645979
[2525 3333] TRUE mean 0.4 bwd 443681 220925 2.008287881
[555 111 111] TRUE none 0.4 fwd 505681 156551 3.230135866
[555 111 111] TRUE none 0.4 bwd 334881 208018 1.609865492
[555 111 111] TRUE sum 0.4 fwd 550482 210880 2.610404021
[555 111 111] TRUE sum 0.4 bwd 329521 159769 2.062483961
[555 111 111] TRUE mean 0.4 fwd 550002 221654 2.481353822
[555 111 111] TRUE mean 0.4 bwd 377441 179627 2.101248699
  • Fix a bug in SoftMarginLoss backward when reduction != none: the output gradient tensor shape was set to match the input tensor shape, but it should be a scalar (since the reduced loss is a scalar)
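To illustrate the shape fix above, a NumPy sketch of SoftMarginLoss backward (loss = log(1 + exp(−y·x)) per element); when reduction is sum or mean, the loss is a scalar, so `grad_output` is a scalar too, not an input-shaped tensor. This is an illustrative sketch, not the MIOpen implementation:

```python
import numpy as np

def soft_margin_loss_backward(x, y, grad_output, reduction="mean"):
    # d/dx log(1 + exp(-y*x)) = -y * exp(-y*x) / (1 + exp(-y*x))
    dldx = -y * np.exp(-y * x) / (1.0 + np.exp(-y * x))
    if reduction == "mean":
        dldx /= x.size
    # With reduction != "none", grad_output is a scalar: the reduced loss
    # is a scalar, so its incoming gradient must be scalar-shaped as well.
    return dldx * grad_output
```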
