aboutsummaryrefslogtreecommitdiff
path: root/platform/linux-generic/odp_ml_fp16.c
blob: 47b10f841ecfc67b9956162126d1f7aded2898e3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
/* SPDX-License-Identifier: BSD-3-Clause
 * Copyright (c) 2022-2023 Marvell.
 * Copyright (c) 2023 Nokia
 *
 * Based on
 * - dpdk/lib/mldev/mldev_utils_scalar.h
 * - dpdk/lib/mldev/mldev_utils_scalar.c
 * - dpdk/lib/mldev/mldev_utils_scalar_bfloat16.c
 */

#include <odp_ml_fp16.h>

#include <errno.h>
#include <stdint.h>

/* Generic bit helpers. Each one is only defined here when no previously
 * included header already provides a macro with the same name.
 */

/* Mask with only bit 'nr' set; the expression has type unsigned long. */
#ifndef BIT
#define BIT(nr) (1UL << (nr))
#endif

/* Width of unsigned long in bits. __SIZEOF_LONG__ is predefined by GCC and
 * Clang, so this fallback assumes one of those compilers.
 */
#ifndef BITS_PER_LONG
#define BITS_PER_LONG (__SIZEOF_LONG__ * 8)
#endif

/* Mask with bits l..h (inclusive) set, e.g. GENMASK_U32(14, 10) == 0x7c00.
 * Despite the _U32 suffix the expression is computed in, and typed as,
 * unsigned long. 'h' must be below BITS_PER_LONG, otherwise the right shift
 * count becomes negative.
 */
#ifndef GENMASK_U32
#define GENMASK_U32(h, l) (((~0UL) << (l)) & (~0UL >> (BITS_PER_LONG - 1 - (h))))
#endif

/* float32 (1 sign, 8 exponent, 23 mantissa bits):
 * bit index of MSB & LSB of sign, exponent and mantissa
 */
#define FP32_LSB_M 0
#define FP32_MSB_M 22
#define FP32_LSB_E 23
#define FP32_MSB_E 30
#define FP32_LSB_S 31
#define FP32_MSB_S 31

/* float32: bitmask for sign, exponent and mantissa, in their in-word position */
#define FP32_MASK_S GENMASK_U32(FP32_MSB_S, FP32_LSB_S)
#define FP32_MASK_E GENMASK_U32(FP32_MSB_E, FP32_LSB_E)
#define FP32_MASK_M GENMASK_U32(FP32_MSB_M, FP32_LSB_M)

/* float16 (1 sign, 5 exponent, 10 mantissa bits):
 * bit index of MSB & LSB of sign, exponent and mantissa
 */
#define FP16_LSB_M 0
#define FP16_MSB_M 9
#define FP16_LSB_E 10
#define FP16_MSB_E 14
#define FP16_LSB_S 15
#define FP16_MSB_S 15

/* float16: bitmask for sign, exponent and mantissa, in their in-word position */
#define FP16_MASK_S GENMASK_U32(FP16_MSB_S, FP16_LSB_S)
#define FP16_MASK_E GENMASK_U32(FP16_MSB_E, FP16_LSB_E)
#define FP16_MASK_M GENMASK_U32(FP16_MSB_M, FP16_LSB_M)

/* bfloat16 (1 sign, 8 exponent, 7 mantissa bits):
 * bit index of MSB & LSB of sign, exponent and mantissa
 */
#define BF16_LSB_M 0
#define BF16_MSB_M 6
#define BF16_LSB_E 7
#define BF16_MSB_E 14
#define BF16_LSB_S 15
#define BF16_MSB_S 15

/* bfloat16: bitmask for sign, exponent and mantissa, in their in-word position */
#define BF16_MASK_S GENMASK_U32(BF16_MSB_S, BF16_LSB_S)
#define BF16_MASK_E GENMASK_U32(BF16_MSB_E, BF16_LSB_E)
#define BF16_MASK_M GENMASK_U32(BF16_MSB_M, BF16_LSB_M)

/* Exponent bias */
#define FP32_BIAS_E 127
#define FP16_BIAS_E 15
#define BF16_BIAS_E 127

/* Assemble a bit pattern from right-aligned (bit 0 based) field values.
 * The fields are NOT masked, so callers must pass values that fit their field.
 */
#define FP32_PACK(sign, exponent, mantissa)                                                        \
	(((sign) << FP32_LSB_S) | ((exponent) << FP32_LSB_E) | (mantissa))

#define FP16_PACK(sign, exponent, mantissa)                                                        \
	(((sign) << FP16_LSB_S) | ((exponent) << FP16_LSB_E) | (mantissa))

#define BF16_PACK(sign, exponent, mantissa)                                                        \
	(((sign) << BF16_LSB_S) | ((exponent) << BF16_LSB_E) | (mantissa))

/* Represent float32 as float and uint32_t: write one member and read the other
 * to reinterpret the raw bits (type punning through a union is permitted in C).
 * NOTE(review): assumes float is a 32-bit IEEE-754 binary32; a static assert on
 * sizeof(float) == sizeof(uint32_t) would make this explicit.
 */
union float32 {
	float f;
	uint32_t u;
};

/* Shift 'mant' right by 'tbits' bits and round the result to nearest, ties to
 * even, based on the discarded low bits.
 *
 * 'tbits' must be in [1, 31]. The result may be one past the largest value of
 * the destination mantissa field (e.g. 0x400 for the 10-bit float16 field);
 * callers treat that carry as an increment of the exponent.
 */
static uint32_t
fp16_shift_round_rtn(uint32_t mant, uint32_t tbits)
{
	uint32_t kept = mant >> tbits;			  /* bits that survive */
	uint32_t rest = mant & GENMASK_U32(tbits - 1, 0); /* truncated bits */
	uint32_t half = BIT(tbits - 1);			  /* MSB of truncated bits */

	/* round up above the halfway point; on an exact tie round to even */
	if (rest > half || (rest == half && (kept & 0x1)))
		kept++;

	return kept;
}

/* Convert a single precision floating point number (float32) into a half precision
 * floating point number (float16) using round to nearest, ties to even.
 *
 * - float32 zeros and subnormals (magnitude below 2^-126) become signed zero,
 *   which is the correctly rounded result in float16.
 * - Infinity stays infinity; NaN becomes a NaN with the mantissa MSB set, keeping
 *   the top float32 mantissa bits.
 * - Magnitudes too large for float16 (including rounding carries out of the
 *   largest finite value) become signed infinity.
 * - Magnitudes in the float16 subnormal range are rounded to a subnormal, or to
 *   the smallest normal value when rounding carries.
 *
 * Returns the float16 bit pattern.
 */
static uint16_t
__float32_to_float16_scalar_rtn(float x)
{
	union float32 f32; /* float32 input */
	uint32_t f32_s;	   /* float32 sign */
	uint32_t f32_e;	   /* float32 biased exponent */
	uint32_t f32_m;	   /* float32 mantissa */
	uint16_t f16_s;	   /* float16 sign */
	uint16_t f16_e;	   /* float16 biased exponent */
	uint16_t f16_m;	   /* float16 mantissa */
	uint32_t m_16;	   /* rounded float16 mantissa, may carry into bit FP16_LSB_E */
	uint32_t tbits;	   /* number of truncated bits */
	int be_16;	   /* float16 biased exponent, signed */

	f32.f = x;
	f32_s = (f32.u & FP32_MASK_S) >> FP32_LSB_S;
	f32_e = (f32.u & FP32_MASK_E) >> FP32_LSB_E;
	f32_m = (f32.u & FP32_MASK_M) >> FP32_LSB_M;

	f16_s = f32_s;
	f16_e = 0;
	f16_m = 0;

	switch (f32_e) {
	case (0): /* float32: zero or subnormal number, far below float16 range */
		f16_e = 0;
		f16_m = 0; /* convert to zero */
		break;
	case (FP32_MASK_E >> FP32_LSB_E): /* float32: infinity or nan */
		f16_e = FP16_MASK_E >> FP16_LSB_E;
		if (f32_m == 0) { /* infinity */
			f16_m = 0;
		} else { /* nan, propagate mantissa and set MSB of mantissa to 1 */
			f16_m = f32_m >> (FP32_MSB_M - FP16_MSB_M);
			f16_m |= BIT(FP16_MSB_M);
		}
		break;
	default: /* float32: normal number */
		/* compute biased exponent for float16 */
		be_16 = (int)f32_e - FP32_BIAS_E + FP16_BIAS_E;

		if (be_16 >= (int)(FP16_MASK_E >> FP16_LSB_E)) {
			/* overflow, be_16 = [31:INF], set to infinity */
			f16_e = FP16_MASK_E >> FP16_LSB_E;
			f16_m = 0;
		} else if (be_16 >= 1) {
			/* normal float16, be_16 = [1:30] */
			m_16 = fp16_shift_round_rtn(f32_m, FP32_LSB_E - FP16_LSB_E);

			/* A rounding carry out of the mantissa increments the exponent.
			 * With be_16 = 30 this yields infinity (the mantissa is then 0).
			 */
			f16_e = be_16;
			if (m_16 & BIT(FP16_LSB_E))
				f16_e++;
			f16_m = m_16 & FP16_MASK_M;
		} else if (be_16 >= -(int)FP16_MSB_M) {
			/* float16 subnormal, be_16 = [-9:0]: make the implicit leading one
			 * explicit, then shift it down to the subnormal scale (2^-24 per LSB)
			 */
			tbits = FP32_LSB_E - FP16_LSB_E + 1 - be_16;
			m_16 = fp16_shift_round_rtn(f32_m | BIT(FP32_LSB_E), tbits);

			/* a rounding carry produces the smallest normal float16 */
			f16_e = 0;
			if (m_16 & BIT(FP16_LSB_E))
				f16_e++;
			f16_m = m_16 & FP16_MASK_M;
		} else if (be_16 == -(int)(FP16_MSB_M + 1)) {
			/* be_16 = [-10]: |x| is in [2^-25, 2^-24), i.e. at least half of the
			 * smallest float16 subnormal. Exactly half is a tie and rounds to even
			 * (zero); anything larger rounds up to the smallest subnormal.
			 */
			f16_e = 0;
			f16_m = (f32_m != 0) ? 1 : 0;
		} else {
			/* underflow: zero, be_16 = [-INF:-11] */
			f16_e = 0;
			f16_m = 0;
		}

		break;
	}

	return FP16_PACK(f16_s, f16_e, f16_m);
}

/* Convert a half precision floating point number (float16) into a single precision
 * floating point number (float32).
 *
 * Every float16 value, including subnormals, is exactly representable in float32,
 * so no rounding takes place:
 * - infinity stays infinity; NaN becomes a NaN with the mantissa MSB set that keeps
 *   the float16 mantissa in the top float32 mantissa bits;
 * - float16 subnormals are normalized into float32 normal numbers;
 * - signed zeros keep their sign.
 */
static float
__float16_to_float32_scalar_rtx(uint16_t f16)
{
	union float32 f32; /* float32 output */
	uint16_t f16_s;	   /* float16 sign */
	uint16_t f16_e;	   /* float16 biased exponent */
	uint16_t f16_m;	   /* float16 mantissa */
	uint32_t f32_s;	   /* float32 sign */
	uint32_t f32_e;	   /* float32 biased exponent */
	uint32_t f32_m;	   /* float32 mantissa */
	uint8_t shift;	   /* number of bits to be shifted */
	int clz;	   /* leading zeroes of f16_m within the 10-bit mantissa field */
	int e_16;	   /* float16 biased exponent (after normalization for subnormals) */

	f16_s = (f16 & FP16_MASK_S) >> FP16_LSB_S;
	f16_e = (f16 & FP16_MASK_E) >> FP16_LSB_E;
	f16_m = (f16 & FP16_MASK_M) >> FP16_LSB_M;

	f32_s = f16_s;
	switch (f16_e) {
	case (FP16_MASK_E >> FP16_LSB_E): /* float16: infinity or nan */
		f32_e = FP32_MASK_E >> FP32_LSB_E;
		if (f16_m == 0x0) { /* infinity */
			f32_m = 0;
		} else { /* nan, propagate mantissa, set MSB of mantissa to 1 */
			shift = FP32_MSB_M - FP16_MSB_M;
			f32_m = ((uint32_t)f16_m << shift) & FP32_MASK_M;
			f32_m |= BIT(FP32_MSB_M);
		}
		break;
	case 0: /* float16: zero or subnormal */
		if (f16_m == 0) { /* signed zero */
			f32_e = 0;
			f32_m = 0;
		} else {
			/* Subnormal: normalize so the leading one becomes the implicit bit.
			 * f16_m is non-zero here, so __builtin_clz() has a defined result.
			 * clz and e_16 are plain ints so the subtraction below is done in
			 * signed arithmetic instead of wrapping through unsigned.
			 */
			clz = __builtin_clz((uint32_t)f16_m) - (int)(sizeof(uint32_t) * 8) +
			      FP16_LSB_E;
			e_16 = (int)f16_e - clz;
			f32_e = FP32_BIAS_E + e_16 - FP16_BIAS_E;

			/* shift the leading one up to bit FP32_LSB_E, where the mask drops it */
			shift = clz + (FP32_MSB_M - FP16_MSB_M) + 1;
			f32_m = ((uint32_t)f16_m << shift) & FP32_MASK_M;
		}
		break;
	default: /* normal numbers: rebias the exponent, left-align the mantissa */
		e_16 = (int)f16_e;
		f32_e = FP32_BIAS_E + e_16 - FP16_BIAS_E;

		shift = FP32_MSB_M - FP16_MSB_M;
		f32_m = ((uint32_t)f16_m << shift) & FP32_MASK_M;
		break;
	}

	f32.u = FP32_PACK(f32_s, f32_e, f32_m);

	return f32.f;
}

/* Convert a single precision floating point number (float32) into a
 * brain float number (bfloat16) using round to nearest rounding mode.
 */
static uint16_t
__float32_to_bfloat16_scalar_rtn(float x)
{
	union float32 f32; /* float32 input */
	uint32_t f32_s;	   /* float32 sign */
	uint32_t f32_e;	   /* float32 exponent */
	uint32_t f32_m;	   /* float32 mantissa */
	uint16_t b16_s;	   /* float16 sign */
	uint16_t b16_e;	   /* float16 exponent */
	uint16_t b16_m;	   /* float16 mantissa */
	uint32_t tbits;	   /* number of truncated bits */
	uint16_t u16;	   /* float16 output */

	f32.f = x;
	f32_s = (f32.u & FP32_MASK_S) >> FP32_LSB_S;
	f32_e = (f32.u & FP32_MASK_E) >> FP32_LSB_E;
	f32_m = (f32.u & FP32_MASK_M) >> FP32_LSB_M;

	b16_s = f32_s;
	b16_e = 0;
	b16_m = 0;

	switch (f32_e) {
	case (0): /* float32: zero or subnormal number */
		b16_e = 0;
		if (f32_m == 0) /* zero */
			b16_m = 0;
		else /* subnormal float32 number, normal bfloat16 */
			goto bf16_normal;
		break;
	case (FP32_MASK_E >> FP32_LSB_E): /* float32: infinity or nan */
		b16_e = BF16_MASK_E >> BF16_LSB_E;
		if (f32_m == 0) { /* infinity */
			b16_m = 0;
		} else { /* nan, propagate mantissa and set MSB of mantissa to 1 */
			b16_m = f32_m >> (FP32_MSB_M - BF16_MSB_M);
			b16_m |= BIT(BF16_MSB_M);
		}
		break;
	default: /* float32: normal number, normal bfloat16 */
		goto bf16_normal;
	}

	goto bf16_pack;

bf16_normal:
	b16_e = f32_e;
	tbits = FP32_MSB_M - BF16_MSB_M;
	b16_m = f32_m >> tbits;

	/* if non-leading truncated bits are set */
	if ((f32_m & GENMASK_U32(tbits - 1, 0)) > BIT(tbits - 1)) {
		b16_m++;

		/* if overflow into exponent */
		if (((b16_m & BF16_MASK_E) >> BF16_LSB_E) == 0x1)
			b16_e++;
	} else if ((f32_m & GENMASK_U32(tbits - 1, 0)) == BIT(tbits - 1)) {
		/* if only leading truncated bit is set */
		if ((b16_m & 0x1) == 0x1) {
			b16_m++;

			/* if overflow into exponent */
			if (((b16_m & BF16_MASK_E) >> BF16_LSB_E) == 0x1)
				b16_e++;
		}
	}
	b16_m = b16_m & BF16_MASK_M;

bf16_pack:
	u16 = BF16_PACK(b16_s, b16_e, b16_m);

	return u16;
}

/* Convert a brain float number (bfloat16) into a
 * single precision floating point number (float32).
 */
static float
__bfloat16_to_float32_scalar_rtx(uint16_t f16)
{
	union float32 f32; /* float32 output */
	uint16_t b16_s;	   /* float16 sign */
	uint16_t b16_e;	   /* float16 exponent */
	uint16_t b16_m;	   /* float16 mantissa */
	uint32_t f32_s;	   /* float32 sign */
	uint32_t f32_e;	   /* float32 exponent */
	uint32_t f32_m;	   /* float32 mantissa*/
	uint8_t shift;	   /* number of bits to be shifted */

	b16_s = (f16 & BF16_MASK_S) >> BF16_LSB_S;
	b16_e = (f16 & BF16_MASK_E) >> BF16_LSB_E;
	b16_m = (f16 & BF16_MASK_M) >> BF16_LSB_M;

	f32_s = b16_s;
	switch (b16_e) {
	case (BF16_MASK_E >> BF16_LSB_E): /* bfloat16: infinity or nan */
		f32_e = FP32_MASK_E >> FP32_LSB_E;
		if (b16_m == 0x0) { /* infinity */
			f32_m = 0;
		} else { /* nan, propagate mantissa, set MSB of mantissa to 1 */
			f32_m = b16_m;
			shift = FP32_MSB_M - BF16_MSB_M;
			f32_m = (f32_m << shift) & FP32_MASK_M;
			f32_m |= BIT(FP32_MSB_M);
		}
		break;
	case 0: /* bfloat16: zero or subnormal */
		f32_m = b16_m;
		if (b16_m == 0) { /* zero signed */
			f32_e = 0;
		} else { /* subnormal numbers */
			goto fp32_normal;
		}
		break;
	default: /* bfloat16: normal number */
		goto fp32_normal;
	}

	goto fp32_pack;

fp32_normal:
	f32_m = b16_m;
	f32_e = FP32_BIAS_E + b16_e - BF16_BIAS_E;

	shift = (FP32_MSB_M - BF16_MSB_M);
	f32_m = (f32_m << shift) & FP32_MASK_M;

fp32_pack:
	f32.u = FP32_PACK(f32_s, f32_e, f32_m);

	return f32.f;
}

/* Convert float32 'x' to a float16 bit pattern, rounding to nearest with ties
 * to even.
 */
uint16_t _odp_float32_to_float16(float x)
{
	const uint16_t bits = __float32_to_float16_scalar_rtn(x);

	return bits;
}

/* Convert the float16 bit pattern 'f16' to float32 (exact, no rounding). */
float _odp_float16_to_float32(uint16_t f16)
{
	const float value = __float16_to_float32_scalar_rtx(f16);

	return value;
}

/* Convert float32 'x' to a bfloat16 bit pattern, rounding to nearest with ties
 * to even.
 */
uint16_t _odp_float32_to_bfloat16(float x)
{
	const uint16_t bits = __float32_to_bfloat16_scalar_rtn(x);

	return bits;
}

/* Convert the bfloat16 bit pattern 'f16' to float32 (exact, no rounding). */
float _odp_bfloat16_to_float32(uint16_t f16)
{
	const float value = __bfloat16_to_float32_scalar_rtx(f16);

	return value;
}