forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathHalf.cpp
105 lines (84 loc) · 2.3 KB
/
Half.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#include <ATen/core/Half.h>

#include <cstring>
#include <iostream>
#include <type_traits>
namespace at {
// Compile-time guard: the build fails with the message below if Half ever
// stops being a standard-layout type.
// NOTE(review): presumably needed so a Half can be treated as its raw 16-bit
// storage (e.g. copied bytewise) -- confirm against the definition in Half.h.
static_assert(
std::is_standard_layout<Half>::value,
"at::Half must be standard layout.");
namespace detail {
// Host functions for converting between FP32 and FP16 formats
// Expands a 16-bit half-precision bit pattern into the float it encodes.
//
// `h` is read as 1 sign bit, 5 exponent bits and 10 fraction bits. Zeros,
// subnormals, normals and infinities convert exactly. Every NaN input yields
// the same positive NaN pattern (0x7fffffff), so NaN sign and payload are
// discarded. The assembled bit pattern is copied into a float with memcpy,
// which relies on `unsigned` and `float` having the same size.
float halfbits2float(unsigned short h) {
  unsigned sign_bit = (h >> 15) & 1u;
  unsigned biased_exp = (h >> 10) & 0x1fu;
  // Park the 10 half fraction bits at the top of the 23-bit float fraction.
  unsigned fraction = (h & 0x3ffu) << 13;

  if (biased_exp == 0x1f) {
    // All-ones exponent: Inf when the fraction is empty, NaN otherwise.
    if (fraction != 0) {
      sign_bit = 0;
      fraction = 0x7fffff;
    }
    biased_exp = 0xff;
  } else if (biased_exp != 0) {
    // Normal number: move the exponent from bias 15 to bias 127.
    biased_exp += 0x70;
  } else if (fraction != 0) {
    // Subnormal half: shift left until the leading 1 reaches bit 22, lowering
    // the exponent once per shift ...
    biased_exp = 0x70;
    while ((fraction & 0x400000) == 0) {
      fraction <<= 1;
      --biased_exp;
    }
    // ... then shift that leading 1 out, because a normalized float keeps it
    // implicit.
    fraction = (fraction << 1) & 0x7fffff;
  }
  // A zero exponent with a zero fraction (+/-0) needs no adjustment.

  const unsigned packed = (sign_bit << 31) | (biased_exp << 23) | fraction;

  // Reinterpret the assembled bits as a float without aliasing violations.
  float value;
  std::memcpy(&value, &packed, sizeof(value));
  return value;
}
// Converts a float to the nearest 16-bit half-precision bit pattern.
//
// - Any NaN (either sign) becomes 0x7fff.
// - Magnitudes whose bit pattern exceeds 0x477fefff become a signed infinity.
// - Magnitudes whose bit pattern is below 0x33000001 become a signed zero.
// - Everything else is rounded to nearest, ties to even, producing a normal
//   or subnormal half as appropriate.
// The float is read through memcpy, which relies on `unsigned` and `float`
// having the same size.
unsigned short float2halfbits(float src) {
  // Reinterpret the float as its raw bits.
  unsigned bits;
  std::memcpy(&bits, &src, sizeof(bits));
  const unsigned magnitude = bits & 0x7fffffff;

  // NaN is handled before the sign is even extracted, so its sign is dropped.
  if (magnitude > 0x7f800000) {
    return 0x7fffU;
  }

  // Float sign bit (bit 31) moved down to the half sign position (bit 15).
  const unsigned sign = (bits >> 16) & 0x8000;

  // Too large for a finite half (this also covers +/-Inf itself).
  if (magnitude > 0x477fefff) {
    return sign | 0x7c00U;
  }
  // Too small to round up to the smallest subnormal half (also covers +/-0).
  if (magnitude < 0x33000001) {
    return sign;
  }

  unsigned exponent = (magnitude >> 23) & 0xff;
  unsigned mantissa = magnitude & 0x7fffff;
  unsigned shift;
  if (exponent <= 0x70) {
    // Result is a subnormal half: make the implicit leading 1 explicit and
    // shift further right the smaller the input is.
    shift = 0x7e - exponent;
    exponent = 0;
    mantissa |= 0x800000;
  } else {
    // Result is a normal half: rebias the exponent and drop 13 fraction bits.
    shift = 13;
    exponent -= 0x70;
  }

  // Split off the bits about to be discarded; `half_ulp` is exactly half of
  // the last retained bit. `shift` is always in [13, 24] here.
  const unsigned half_ulp = 1u << (shift - 1);
  const unsigned discarded = mantissa & ((1u << shift) - 1u);
  mantissa >>= shift;

  // Round to nearest; on an exact tie round towards an even mantissa.
  const bool round_up = discarded > half_ulp ||
      (discarded == half_ulp && (mantissa & 0x1) != 0);
  if (round_up) {
    ++mantissa;
    if ((mantissa & 0x3ff) == 0) {
      // The fraction overflowed its 10 bits: carry into the exponent.
      ++exponent;
      mantissa = 0;
    }
  }

  return sign | (exponent << 10) | mantissa;
}
} // namespace detail
// Stream insertion for Half: converts `value` to float and writes that float
// to `out` using the stream's current formatting. Returns `out` so insertions
// can be chained.
std::ostream& operator<<(std::ostream& out, const Half& value) {
  return out << static_cast<float>(value);
}
} // namespace at