@@ -274,6 +274,10 @@ def diff(a_tensor, n=1, axis=-1, prepend_tensor=None, append_tensor=None):
274
274
if n < 0 :
275
275
raise ValueError (f"order must be non-negative but got { n } " )
276
276
277
+ if n == 0 :
278
+ # match numpy and return the input immediately
279
+ return a_tensor
280
+
277
281
if prepend_tensor is not None :
278
282
shape = list (a_tensor .shape )
279
283
shape [axis ] = prepend_tensor .shape [axis ] if prepend_tensor .ndim > 0 else 1
@@ -357,6 +361,14 @@ def vstack(tensors, *, dtype=None, casting="same_kind"):
357
361
return result
358
362
359
363
364
+ def tile (tensor , reps ):
365
+ if isinstance (reps , int ):
366
+ reps = (reps ,)
367
+
368
+ result = torch .tile (tensor , reps )
369
+ return result
370
+
371
+
360
372
# #### cov & corrcoef
361
373
362
374
@@ -450,6 +462,10 @@ def indices(dimensions, dtype=int, sparse=False):
450
462
451
463
452
464
def bincount (x_tensor , / , weights_tensor = None , minlength = 0 ):
465
+ if x_tensor .numel () == 0 :
466
+ # edge case allowed by numpy
467
+ x_tensor = torch .as_tensor ([], dtype = int )
468
+
453
469
int_dtype = _dtypes_impl .default_int_dtype
454
470
(x_tensor ,) = _util .cast_dont_broadcast ((x_tensor ,), int_dtype , casting = "safe" )
455
471
@@ -460,6 +476,14 @@ def bincount(x_tensor, /, weights_tensor=None, minlength=0):
460
476
# ### linspace, geomspace, logspace and arange ###
461
477
462
478
479
+ def linspace (start , stop , num = 50 , endpoint = True , retstep = False , dtype = None , axis = 0 ):
480
+ if axis != 0 or retstep or not endpoint :
481
+ raise NotImplementedError
482
+ # XXX: raises TypeError if start or stop are not scalars
483
+ result = torch .linspace (start , stop , num , dtype = dtype )
484
+ return result
485
+
486
+
463
487
def geomspace (start , stop , num = 50 , endpoint = True , dtype = None , axis = 0 ):
464
488
if axis != 0 or not endpoint :
465
489
raise NotImplementedError
@@ -474,6 +498,13 @@ def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):
474
498
return result
475
499
476
500
501
+ def logspace (start , stop , num = 50 , endpoint = True , base = 10.0 , dtype = None , axis = 0 ):
502
+ if axis != 0 or not endpoint :
503
+ raise NotImplementedError
504
+ result = torch .logspace (start , stop , num , base = base , dtype = dtype )
505
+ return result
506
+
507
+
477
508
def arange (start = None , stop = None , step = 1 , dtype = None ):
478
509
if step == 0 :
479
510
raise ZeroDivisionError
@@ -523,36 +554,75 @@ def eye(N, M=None, k=0, dtype=float):
523
554
return z
524
555
525
556
526
- def zeros_like (a , dtype = None , shape = None ):
557
+ def zeros (shape , dtype = None , order = "C" ):
558
+ if order != "C" :
559
+ raise NotImplementedError
560
+ if dtype is None :
561
+ dtype = _dtypes_impl .default_float_dtype
562
+ result = torch .zeros (shape , dtype = dtype )
563
+ return result
564
+
565
+
566
+ def zeros_like (a , dtype = None , shape = None , order = "K" ):
567
+ if order != "K" :
568
+ raise NotImplementedError
527
569
result = torch .zeros_like (a , dtype = dtype )
528
570
if shape is not None :
529
571
result = result .reshape (shape )
530
572
return result
531
573
532
574
533
- def ones_like (a , dtype = None , shape = None ):
575
+ def ones (shape , dtype = None , order = "C" ):
576
+ if order != "C" :
577
+ raise NotImplementedError
578
+ if dtype is None :
579
+ dtype = _dtypes_impl .default_float_dtype
580
+ result = torch .ones (shape , dtype = dtype )
581
+ return result
582
+
583
+
584
+ def ones_like (a , dtype = None , shape = None , order = "K" ):
585
+ if order != "K" :
586
+ raise NotImplementedError
534
587
result = torch .ones_like (a , dtype = dtype )
535
588
if shape is not None :
536
589
result = result .reshape (shape )
537
590
return result
538
591
539
592
540
- def full_like (a , fill_value , dtype = None , shape = None ):
593
+ def full_like (a , fill_value , dtype = None , shape = None , order = "K" ):
594
+ if order != "K" :
595
+ raise NotImplementedError
541
596
# XXX: fill_value broadcasts
542
597
result = torch .full_like (a , fill_value , dtype = dtype )
543
598
if shape is not None :
544
599
result = result .reshape (shape )
545
600
return result
546
601
547
602
548
- def empty_like (prototype , dtype = None , shape = None ):
603
+ def empty (shape , dtype = None , order = "C" ):
604
+ if order != "C" :
605
+ raise NotImplementedError
606
+ if dtype is None :
607
+ dtype = _dtypes_impl .default_float_dtype
608
+ result = torch .empty (shape , dtype = dtype )
609
+ return result
610
+
611
+
612
+ def empty_like (prototype , dtype = None , shape = None , order = "K" ):
613
+ if order != "K" :
614
+ raise NotImplementedError
549
615
result = torch .empty_like (prototype , dtype = dtype )
550
616
if shape is not None :
551
617
result = result .reshape (shape )
552
618
return result
553
619
554
620
555
- def full (shape , fill_value , dtype = None ):
621
+ def full (shape , fill_value , dtype = None , order = "C" ):
622
+ if isinstance (shape , int ):
623
+ shape = (shape ,)
624
+ if order != "C" :
625
+ raise NotImplementedError
556
626
if dtype is None :
557
627
dtype = fill_value .dtype
558
628
if not isinstance (shape , (tuple , list )):
0 commit comments