8
8
#[ inline]
9
9
#[ target_feature( enable = "amx-tile" ) ]
10
10
#[ unstable( feature = "x86_amx_intrinsics" , issue = "126622" ) ]
11
- pub unsafe fn _tile_loadconfig ( mem_addr : * const i8 ) {
11
+ pub unsafe fn _tile_loadconfig ( mem_addr : * const u8 ) {
12
12
ldtilecfg ( mem_addr) ;
13
13
}
14
14
@@ -20,7 +20,7 @@ pub unsafe fn _tile_loadconfig(mem_addr: *const i8) {
20
20
#[ inline]
21
21
#[ target_feature( enable = "amx-tile" ) ]
22
22
#[ unstable( feature = "x86_amx_intrinsics" , issue = "126622" ) ]
23
- pub unsafe fn _tile_storeconfig ( mem_addr : * mut i8 ) {
23
+ pub unsafe fn _tile_storeconfig ( mem_addr : * mut u8 ) {
24
24
sttilecfg ( mem_addr) ;
25
25
}
26
26
@@ -31,7 +31,7 @@ pub unsafe fn _tile_storeconfig(mem_addr: *mut i8) {
31
31
#[ rustc_legacy_const_generics( 0 ) ]
32
32
#[ target_feature( enable = "amx-tile" ) ]
33
33
#[ unstable( feature = "x86_amx_intrinsics" , issue = "126622" ) ]
34
- pub unsafe fn _tile_loadd < const DST : i8 > ( base : * const i8 , stride : usize ) {
34
+ pub unsafe fn _tile_loadd < const DST : i8 > ( base : * const u8 , stride : usize ) {
35
35
static_assert_uimm_bits ! ( DST , 3 ) ;
36
36
tileloadd64 ( DST , base, stride) ;
37
37
}
@@ -53,7 +53,7 @@ pub unsafe fn _tile_release() {
53
53
#[ rustc_legacy_const_generics( 0 ) ]
54
54
#[ target_feature( enable = "amx-tile" ) ]
55
55
#[ unstable( feature = "x86_amx_intrinsics" , issue = "126622" ) ]
56
- pub unsafe fn _tile_stored < const DST : i8 > ( base : * mut i8 , stride : usize ) {
56
+ pub unsafe fn _tile_stored < const DST : i8 > ( base : * mut u8 , stride : usize ) {
57
57
static_assert_uimm_bits ! ( DST , 3 ) ;
58
58
tilestored64 ( DST , base, stride) ;
59
59
}
@@ -67,7 +67,7 @@ pub unsafe fn _tile_stored<const DST: i8>(base: *mut i8, stride: usize) {
67
67
#[ rustc_legacy_const_generics( 0 ) ]
68
68
#[ target_feature( enable = "amx-tile" ) ]
69
69
#[ unstable( feature = "x86_amx_intrinsics" , issue = "126622" ) ]
70
- pub unsafe fn _tile_stream_loadd < const DST : i8 > ( base : * const i8 , stride : usize ) {
70
+ pub unsafe fn _tile_stream_loadd < const DST : i8 > ( base : * const u8 , stride : usize ) {
71
71
static_assert_uimm_bits ! ( DST , 3 ) ;
72
72
tileloaddt164 ( DST , base, stride) ;
73
73
}
@@ -227,17 +227,17 @@ pub unsafe fn _tile_cmmrlfp16ps<const DST: i8, const A: i8, const B: i8>() {
227
227
#[ allow( improper_ctypes) ]
228
228
extern "C" {
229
229
#[ link_name = "llvm.x86.ldtilecfg" ]
230
- fn ldtilecfg ( mem_addr : * const i8 ) ;
230
+ fn ldtilecfg ( mem_addr : * const u8 ) ;
231
231
#[ link_name = "llvm.x86.sttilecfg" ]
232
- fn sttilecfg ( mem_addr : * mut i8 ) ;
232
+ fn sttilecfg ( mem_addr : * mut u8 ) ;
233
233
#[ link_name = "llvm.x86.tileloadd64" ]
234
- fn tileloadd64 ( dst : i8 , base : * const i8 , stride : usize ) ;
234
+ fn tileloadd64 ( dst : i8 , base : * const u8 , stride : usize ) ;
235
235
#[ link_name = "llvm.x86.tileloaddt164" ]
236
- fn tileloaddt164 ( dst : i8 , base : * const i8 , stride : usize ) ;
236
+ fn tileloaddt164 ( dst : i8 , base : * const u8 , stride : usize ) ;
237
237
#[ link_name = "llvm.x86.tilerelease" ]
238
238
fn tilerelease ( ) ;
239
239
#[ link_name = "llvm.x86.tilestored64" ]
240
- fn tilestored64 ( dst : i8 , base : * mut i8 , stride : usize ) ;
240
+ fn tilestored64 ( dst : i8 , base : * mut u8 , stride : usize ) ;
241
241
#[ link_name = "llvm.x86.tilezero" ]
242
242
fn tilezero ( dst : i8 ) ;
243
243
#[ link_name = "llvm.x86.tdpbf16ps" ]
@@ -267,6 +267,47 @@ mod tests {
267
267
#[ cfg( target_os = "linux" ) ]
268
268
use syscalls:: { syscall, Sysno } ;
269
269
270
/// Byte image of an AMX tile configuration, as handed to `_tile_loadconfig`
/// through [`__tilecfg::as_ptr`].
///
/// The layout must be exactly the declared field order with no padding:
///
/// | bytes    | field         |
/// |----------|---------------|
/// | 0        | `palette`     |
/// | 1        | `start_row`   |
/// | 2..16    | `reserved_a0` |
/// | 16..32   | `colsb`       |
/// | 32..48   | `reserved_b0` |
/// | 48..56   | `rows`        |
/// | 56..64   | `reserved_c0` |
///
/// `repr(C, packed)` is required for that: `packed` on its own only removes
/// padding and leaves the field order unspecified, so the compiler would be
/// free to reorder the fields.
#[allow(non_camel_case_types)]
#[repr(C, packed)]
#[derive(Copy, Clone, Default, Debug, PartialEq)]
struct __tilecfg {
    /// Palette id: either 0 or 1.
    palette: u8,
    start_row: u8,
    /// Reserved, must be zero.
    reserved_a0: [u8; 14],
    /// Number of bytes of one row in each tile.
    colsb: [u16; 8],
    /// Reserved, must be zero.
    reserved_b0: [u16; 8],
    /// Number of rows in each tile.
    rows: [u8; 8],
    /// Reserved, must be zero.
    reserved_c0: [u8; 8],
}

// Compile-time guard: the raw-pointer views below expose the struct as a
// 64-byte block, so any change that alters its size must fail the build.
const _: () = assert!(core::mem::size_of::<__tilecfg>() == 64);

impl __tilecfg {
    /// Builds a configuration from the caller-controlled fields; every
    /// reserved region is zero-filled.
    fn new(palette: u8, start_row: u8, colsb: [u16; 8], rows: [u8; 8]) -> Self {
        Self {
            palette,
            start_row,
            reserved_a0: [0u8; 14],
            colsb,
            reserved_b0: [0u16; 8],
            rows,
            reserved_c0: [0u8; 8],
        }
    }

    /// Read-only byte pointer to the start of the configuration block.
    const fn as_ptr(&self) -> *const u8 {
        (self as *const Self).cast::<u8>()
    }

    /// Writable byte pointer to the start of the configuration block.
    fn as_mut_ptr(&mut self) -> *mut u8 {
        (self as *mut Self).cast::<u8>()
    }
}
310
+
270
311
// Variant of `_init_amx` compiled when the target OS is not Linux.
// The body is empty, so calling it does nothing.
#[ cfg( not( target_os = "linux" ) ) ]
271
312
#[ target_feature( enable = "amx-tile" ) ]
272
313
fn _init_amx ( ) { }
@@ -324,7 +365,7 @@ mod tests {
324
365
_tile_loadconfig ( config. as_ptr ( ) ) ;
325
366
_tile_zero :: < 0 > ( ) ;
326
367
let mut out = [ [ 1_i8 ; 64 ] ; 16 ] ;
327
- _tile_stored :: < 0 > ( & mut out as * mut [ i8 ; 64 ] as * mut i8 , 64 ) ;
368
+ _tile_stored :: < 0 > ( & mut out as * mut [ i8 ; 64 ] as * mut u8 , 64 ) ;
328
369
_tile_release ( ) ;
329
370
assert_eq ! ( out, [ [ 0 ; 64 ] ; 16 ] ) ;
330
371
}
@@ -339,7 +380,7 @@ mod tests {
339
380
_tile_loadconfig ( config. as_ptr ( ) ) ;
340
381
_tile_zero :: < 0 > ( ) ;
341
382
let mut out = [ [ 1_i8 ; 64 ] ; 16 ] ;
342
- _tile_stored :: < 0 > ( & mut out as * mut [ i8 ; 64 ] as * mut i8 , 64 ) ;
383
+ _tile_stored :: < 0 > ( & mut out as * mut [ i8 ; 64 ] as * mut u8 , 64 ) ;
343
384
_tile_release ( ) ;
344
385
assert_eq ! ( out, [ [ 0 ; 64 ] ; 16 ] ) ;
345
386
}
@@ -354,9 +395,9 @@ mod tests {
354
395
_tile_loadconfig ( config. as_ptr ( ) ) ;
355
396
_tile_zero :: < 0 > ( ) ;
356
397
let mat = [ 1_i8 ; 1024 ] ;
357
- _tile_loadd :: < 0 > ( & mat as * const i8 , 64 ) ;
398
+ _tile_loadd :: < 0 > ( & mat as * const i8 as * const u8 , 64 ) ;
358
399
let mut out = [ [ 0_i8 ; 64 ] ; 16 ] ;
359
- _tile_stored :: < 0 > ( & mut out as * mut [ i8 ; 64 ] as * mut i8 , 64 ) ;
400
+ _tile_stored :: < 0 > ( & mut out as * mut [ i8 ; 64 ] as * mut u8 , 64 ) ;
360
401
_tile_release ( ) ;
361
402
assert_eq ! ( out, [ [ 1 ; 64 ] ; 16 ] ) ;
362
403
}
@@ -371,9 +412,9 @@ mod tests {
371
412
_tile_loadconfig ( config. as_ptr ( ) ) ;
372
413
_tile_zero :: < 0 > ( ) ;
373
414
let mat = [ 1_i8 ; 1024 ] ;
374
- _tile_stream_loadd :: < 0 > ( & mat as * const i8 , 64 ) ;
415
+ _tile_stream_loadd :: < 0 > ( & mat as * const i8 as * const u8 , 64 ) ;
375
416
let mut out = [ [ 0_i8 ; 64 ] ; 16 ] ;
376
- _tile_stored :: < 0 > ( & mut out as * mut [ i8 ; 64 ] as * mut i8 , 64 ) ;
417
+ _tile_stored :: < 0 > ( & mut out as * mut [ i8 ; 64 ] as * mut u8 , 64 ) ;
377
418
_tile_release ( ) ;
378
419
assert_eq ! ( out, [ [ 1 ; 64 ] ; 16 ] ) ;
379
420
}
@@ -388,8 +429,8 @@ mod tests {
388
429
_init_amx ( ) ;
389
430
let bf16_1: u16 = _mm_cvtness_sbh ( 1.0 ) . to_bits ( ) ;
390
431
let bf16_2: u16 = _mm_cvtness_sbh ( 2.0 ) . to_bits ( ) ;
391
- let ones: [ i8 ; 1024 ] = transmute ( [ bf16_1; 512 ] ) ;
392
- let twos: [ i8 ; 1024 ] = transmute ( [ bf16_2; 512 ] ) ;
432
+ let ones: [ u8 ; 1024 ] = transmute ( [ bf16_1; 512 ] ) ;
433
+ let twos: [ u8 ; 1024 ] = transmute ( [ bf16_2; 512 ] ) ;
393
434
let mut res = [ [ 0f32 ; 16 ] ; 16 ] ;
394
435
let mut config = __tilecfg:: default ( ) ;
395
436
config. palette = 1 ;
@@ -399,10 +440,10 @@ mod tests {
399
440
} ) ;
400
441
_tile_loadconfig ( config. as_ptr ( ) ) ;
401
442
_tile_zero :: < 0 > ( ) ;
402
- _tile_loadd :: < 1 > ( & ones as * const i8 , 64 ) ;
403
- _tile_loadd :: < 2 > ( & twos as * const i8 , 64 ) ;
443
+ _tile_loadd :: < 1 > ( & ones as * const u8 , 64 ) ;
444
+ _tile_loadd :: < 2 > ( & twos as * const u8 , 64 ) ;
404
445
_tile_dpbf16ps :: < 0 , 1 , 2 > ( ) ;
405
- _tile_stored :: < 0 > ( & mut res as * mut [ f32 ; 16 ] as * mut i8 , 64 ) ;
446
+ _tile_stored :: < 0 > ( & mut res as * mut [ f32 ; 16 ] as * mut u8 , 64 ) ;
406
447
_tile_release ( ) ;
407
448
assert_eq ! ( res, [ [ 64f32 ; 16 ] ; 16 ] ) ;
408
449
}
@@ -421,10 +462,10 @@ mod tests {
421
462
} ) ;
422
463
_tile_loadconfig ( config. as_ptr ( ) ) ;
423
464
_tile_zero :: < 0 > ( ) ;
424
- _tile_loadd :: < 1 > ( & ones as * const i8 , 64 ) ;
425
- _tile_loadd :: < 2 > ( & twos as * const i8 , 64 ) ;
465
+ _tile_loadd :: < 1 > ( & ones as * const i8 as * const u8 , 64 ) ;
466
+ _tile_loadd :: < 2 > ( & twos as * const i8 as * const u8 , 64 ) ;
426
467
_tile_dpbssd :: < 0 , 1 , 2 > ( ) ;
427
- _tile_stored :: < 0 > ( & mut res as * mut [ i32 ; 16 ] as * mut i8 , 64 ) ;
468
+ _tile_stored :: < 0 > ( & mut res as * mut [ i32 ; 16 ] as * mut u8 , 64 ) ;
428
469
_tile_release ( ) ;
429
470
assert_eq ! ( res, [ [ 128_i32 ; 16 ] ; 16 ] ) ;
430
471
}
@@ -443,10 +484,10 @@ mod tests {
443
484
} ) ;
444
485
_tile_loadconfig ( config. as_ptr ( ) ) ;
445
486
_tile_zero :: < 0 > ( ) ;
446
- _tile_loadd :: < 1 > ( & ones as * const i8 , 64 ) ;
447
- _tile_loadd :: < 2 > ( & twos as * const u8 as * const i8 , 64 ) ;
487
+ _tile_loadd :: < 1 > ( & ones as * const i8 as * const u8 , 64 ) ;
488
+ _tile_loadd :: < 2 > ( & twos as * const u8 , 64 ) ;
448
489
_tile_dpbsud :: < 0 , 1 , 2 > ( ) ;
449
- _tile_stored :: < 0 > ( & mut res as * mut [ i32 ; 16 ] as * mut i8 , 64 ) ;
490
+ _tile_stored :: < 0 > ( & mut res as * mut [ i32 ; 16 ] as * mut u8 , 64 ) ;
450
491
_tile_release ( ) ;
451
492
assert_eq ! ( res, [ [ -128_i32 ; 16 ] ; 16 ] ) ;
452
493
}
@@ -465,10 +506,10 @@ mod tests {
465
506
} ) ;
466
507
_tile_loadconfig ( config. as_ptr ( ) ) ;
467
508
_tile_zero :: < 0 > ( ) ;
468
- _tile_loadd :: < 1 > ( & ones as * const u8 as * const i8 , 64 ) ;
469
- _tile_loadd :: < 2 > ( & twos as * const i8 , 64 ) ;
509
+ _tile_loadd :: < 1 > ( & ones as * const u8 , 64 ) ;
510
+ _tile_loadd :: < 2 > ( & twos as * const i8 as * const u8 , 64 ) ;
470
511
_tile_dpbusd :: < 0 , 1 , 2 > ( ) ;
471
- _tile_stored :: < 0 > ( & mut res as * mut [ i32 ; 16 ] as * mut i8 , 64 ) ;
512
+ _tile_stored :: < 0 > ( & mut res as * mut [ i32 ; 16 ] as * mut u8 , 64 ) ;
472
513
_tile_release ( ) ;
473
514
assert_eq ! ( res, [ [ -128_i32 ; 16 ] ; 16 ] ) ;
474
515
}
@@ -487,10 +528,10 @@ mod tests {
487
528
} ) ;
488
529
_tile_loadconfig ( config. as_ptr ( ) ) ;
489
530
_tile_zero :: < 0 > ( ) ;
490
- _tile_loadd :: < 1 > ( & ones as * const u8 as * const i8 , 64 ) ;
491
- _tile_loadd :: < 2 > ( & twos as * const u8 as * const i8 , 64 ) ;
531
+ _tile_loadd :: < 1 > ( & ones as * const u8 , 64 ) ;
532
+ _tile_loadd :: < 2 > ( & twos as * const u8 , 64 ) ;
492
533
_tile_dpbuud :: < 0 , 1 , 2 > ( ) ;
493
- _tile_stored :: < 0 > ( & mut res as * mut [ i32 ; 16 ] as * mut i8 , 64 ) ;
534
+ _tile_stored :: < 0 > ( & mut res as * mut [ i32 ; 16 ] as * mut u8 , 64 ) ;
494
535
_tile_release ( ) ;
495
536
assert_eq ! ( res, [ [ 128_i32 ; 16 ] ; 16 ] ) ;
496
537
}
@@ -509,10 +550,10 @@ mod tests {
509
550
} ) ;
510
551
_tile_loadconfig ( config. as_ptr ( ) ) ;
511
552
_tile_zero :: < 0 > ( ) ;
512
- _tile_loadd :: < 1 > ( & ones as * const f16 as * const i8 , 64 ) ;
513
- _tile_loadd :: < 2 > ( & twos as * const f16 as * const i8 , 64 ) ;
553
+ _tile_loadd :: < 1 > ( & ones as * const f16 as * const u8 , 64 ) ;
554
+ _tile_loadd :: < 2 > ( & twos as * const f16 as * const u8 , 64 ) ;
514
555
_tile_dpfp16ps :: < 0 , 1 , 2 > ( ) ;
515
- _tile_stored :: < 0 > ( & mut res as * mut [ f32 ; 16 ] as * mut i8 , 64 ) ;
556
+ _tile_stored :: < 0 > ( & mut res as * mut [ f32 ; 16 ] as * mut u8 , 64 ) ;
516
557
_tile_release ( ) ;
517
558
assert_eq ! ( res, [ [ 64f32 ; 16 ] ; 16 ] ) ;
518
559
}
@@ -531,10 +572,10 @@ mod tests {
531
572
} ) ;
532
573
_tile_loadconfig ( config. as_ptr ( ) ) ;
533
574
_tile_zero :: < 0 > ( ) ;
534
- _tile_loadd :: < 1 > ( & ones as * const f16 as * const i8 , 64 ) ;
535
- _tile_loadd :: < 2 > ( & twos as * const f16 as * const i8 , 64 ) ;
575
+ _tile_loadd :: < 1 > ( & ones as * const f16 as * const u8 , 64 ) ;
576
+ _tile_loadd :: < 2 > ( & twos as * const f16 as * const u8 , 64 ) ;
536
577
_tile_cmmimfp16ps :: < 0 , 1 , 2 > ( ) ;
537
- _tile_stored :: < 0 > ( & mut res as * mut [ f32 ; 16 ] as * mut i8 , 64 ) ;
578
+ _tile_stored :: < 0 > ( & mut res as * mut [ f32 ; 16 ] as * mut u8 , 64 ) ;
538
579
_tile_release ( ) ;
539
580
assert_eq ! ( res, [ [ 64f32 ; 16 ] ; 16 ] ) ;
540
581
}
@@ -553,10 +594,10 @@ mod tests {
553
594
} ) ;
554
595
_tile_loadconfig ( config. as_ptr ( ) ) ;
555
596
_tile_zero :: < 0 > ( ) ;
556
- _tile_loadd :: < 1 > ( & ones as * const f16 as * const i8 , 64 ) ;
557
- _tile_loadd :: < 2 > ( & twos as * const f16 as * const i8 , 64 ) ;
597
+ _tile_loadd :: < 1 > ( & ones as * const f16 as * const u8 , 64 ) ;
598
+ _tile_loadd :: < 2 > ( & twos as * const f16 as * const u8 , 64 ) ;
558
599
_tile_cmmrlfp16ps :: < 0 , 1 , 2 > ( ) ;
559
- _tile_stored :: < 0 > ( & mut res as * mut [ f32 ; 16 ] as * mut i8 , 64 ) ;
600
+ _tile_stored :: < 0 > ( & mut res as * mut [ f32 ; 16 ] as * mut u8 , 64 ) ;
560
601
_tile_release ( ) ;
561
602
assert_eq ! ( res, [ [ 0f32 ; 16 ] ; 16 ] ) ;
562
603
}
0 commit comments