Output of torch.sum with unsigned input should be unsigned #242
According to the standard, the documentation of sum states for the dtype parameter:

If I understand correctly, the sums of the unsigned arrays below should therefore have dtype uint64:

But the output is:
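To make the expected dtype concrete, here is a minimal sketch of the integer part of the rule as I read it. It works on dtype names only, so no array library is needed. The helper expected_sum_dtype and its default_int parameter are hypothetical, floating-point and complex inputs are deliberately left out, and the exact wording should be checked against the standard.

```python
# Bit widths of the integer dtypes defined by the array API standard (names only).
SIGNED_BITS = {"int8": 8, "int16": 16, "int32": 32, "int64": 64}
UNSIGNED_BITS = {"uint8": 8, "uint16": 16, "uint32": 32, "uint64": 64}


def expected_sum_dtype(in_dtype: str, default_int: str = "int64") -> str:
    """Hypothetical helper: the dtype sum(x) should return when dtype=None,
    for integer inputs only (float and complex inputs are not handled)."""
    default_bits = SIGNED_BITS[default_int]
    if in_dtype in SIGNED_BITS:
        # Signed input: the default integer dtype, unless the input is wider.
        return in_dtype if SIGNED_BITS[in_dtype] > default_bits else default_int
    if in_dtype in UNSIGNED_BITS:
        # Unsigned input: the unsigned dtype with as many bits as the default
        # integer dtype, unless the input is already wider.
        if UNSIGNED_BITS[in_dtype] > default_bits:
            return in_dtype
        return f"uint{default_bits}"
    raise ValueError(f"not an integer dtype: {in_dtype!r}")


for dt in ["uint8", "uint16", "uint32", "uint64", "int8", "int32"]:
    print(dt, "->", expected_sum_dtype(dt))
# uint8 -> uint64, uint16 -> uint64, uint32 -> uint64,
# uint64 -> uint64, int8 -> int64, int32 -> int64
```

With PyTorch's default integer dtype, int64, every unsigned input maps to uint64, which is the dtype the unsigned sums should report.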
I think this is at least partially fixable within array-api-compat.

Also, torch doesn't seem to natively support sum for most uint dtypes or for complex32. If we change xp.sum(x) to xp.sum(x, dtype=dtype) in the code above, the output is:

It would be helpful if array-api-compat implemented sum for these types, even if that means upcasting to a supported type before summing and then downcasting. (There is a slightly larger chance of overflow with int64 than with uint64, since int64 tops out at 2**63 - 1, about half of uint64's maximum of 2**64 - 1. The conversion back may also not be safe, so what should happen in those cases is up for discussion.)

Does array-api-compat have a mechanism for reporting the shortcomings it has to patch back to the underlying libraries? If not, should I report this to PyTorch (if it has not already been reported)?
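For the upcast-then-downcast fallback, a pure-Python sketch can show what happens to the bits. It assumes (my assumption, not checked against PyTorch's documented guarantees) that int64 overflow wraps around in two's complement and that casting a negative int64 to uint64 keeps the bit pattern. The helpers wrap, sum_via_int64 and sum_uint64 are hypothetical and operate on Python ints rather than tensors.

```python
def wrap(value: int, bits: int, signed: bool) -> int:
    """Reduce a Python int to a fixed-width integer, assuming two's-complement wraparound."""
    value %= 1 << bits
    if signed and value >= 1 << (bits - 1):
        value -= 1 << bits
    return value


def sum_via_int64(values):
    """Hypothetical fallback: cast unsigned inputs to int64, accumulate in int64,
    then cast the result back to uint64. Returns (result, accumulator_wrapped)."""
    acc = 0
    wrapped = False
    for v in values:
        # uint -> int64: lossless for uint8/16/32; uint64 values >= 2**63 become negative.
        v = wrap(v, 64, signed=True)
        exact = acc + v
        acc = wrap(exact, 64, signed=True)  # int64 accumulator with wraparound
        wrapped = wrapped or acc != exact
    return wrap(acc, 64, signed=False), wrapped  # int64 -> uint64


def sum_uint64(values):
    """Reference: accumulate directly in uint64 with wraparound."""
    return wrap(sum(values), 64, signed=False)


values = [2**62, 2**62]  # exact sum 2**63 fits in uint64 but not in int64
print(sum_via_int64(values))  # (9223372036854775808, True)
print(sum_uint64(values))     # 9223372036854775808
```

Under that assumption the two functions always return the same value; the int64 accumulator simply wraps earlier (here at 2**63), which is the case the flag reports. Whether array-api-compat should raise, warn, or silently return the wrapped value in that case is the open question raised in the issue.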