RFC: add array support for the shift kwarg in roll #914

Open
amacati opened this issue Mar 15, 2025 · 2 comments
Labels
RFC Request for comments. Feature requests and proposed changes.

Comments

@amacati

amacati commented Mar 15, 2025

The Array API standard specifies two types for the shift argument of xp.roll: int and tuple[int, ...]. However, it is currently possible to run the following piece of code with array_api_strict:

import array_api_strict as xp

a = xp.asarray([1, 2, 3])
b = xp.asarray([1])
print(xp.roll(a, shift=b))  # Type of shift is Array, not tuple[int, ...] or int
# >>> prints [3 1 2]

This is convenient, because it allows the function to be jit-compiled in JAX with JAX arrays as dynamic shifts. However, the spec does not currently guarantee this behaviour, so I would expect the code to fail. Indeed, relying on it is dangerous: executing the same snippet with torch yields

import torch
from array_api_compat import array_namespace

xp = array_namespace(torch.tensor(1))
a = xp.asarray([1, 2, 3])
b = xp.asarray([1])
print(xp.roll(a, shift=b))
# Traceback (most recent call last):
#   File "xxx", line 7, in <module>
#     print(xp.roll(a, shift=b))
#           ^^^^^^^^^^^^^^^^^^^
#   File "xxx", line 503, in roll
#     return torch.roll(x, shift, axis, **kwargs)
#            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# RuntimeError: `shifts` required

Passing an explicit integer instead of a tensor as the shift argument does work.

Remedies

There are two ways to fix this:

  1. Update the spec so that shift also accepts an Array.
  2. Make array_api_strict explicitly check the types of xp.roll's arguments, so that the example above raises an error.
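A minimal sketch of what remedy 2 could look like, assuming a standalone validator. `check_roll_shift` is a hypothetical name, not array_api_strict's actual code. Rejecting `bool` is an extra assumption of this sketch, made because `bool` subclasses `int` in Python.

```python
def check_roll_shift(shift):
    """Raise TypeError unless shift is an int or a tuple of ints.

    Hypothetical helper illustrating remedy 2; not array_api_strict's code.
    Rejecting bool is an assumption about how strict a check should be.
    """
    def is_int(value):
        # bool is a subclass of int, so exclude it explicitly.
        return isinstance(value, int) and not isinstance(value, bool)

    if is_int(shift):
        return
    if isinstance(shift, tuple) and all(is_int(s) for s in shift):
        return
    raise TypeError(
        f"shift must be an int or a tuple of ints, got {type(shift).__name__}"
    )


check_roll_shift(1)        # accepted
check_roll_shift((1, -2))  # accepted
try:
    check_roll_shift([1])  # a list (like an Array) is rejected
except TypeError as exc:
    print(exc)  # shift must be an int or a tuple of ints, got list
```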

If the spec is updated to include arrays, array_api_compat will need to update its torch wrapper for roll so that the example above runs without errors. I would be strongly in favour of accepting arrays (both zero- and one-dimensional), because doing so lets us do more with xp.roll inside jit-compiled functions.
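For remedy 1, a compat wrapper around a backend whose roll only takes Python ints would first have to turn an array-valued shift into an int or a tuple of ints. The sketch below uses a hypothetical helper, `normalize_shift`, and assumes arrays expose `.tolist()` (NumPy arrays and PyTorch tensors do). Because it reads concrete values off the array, this approach would not work for a traced shift inside a jit-compiled function; that use case needs native support in the backend.

```python
from array import array


def normalize_shift(shift):
    """Convert a 0-D or 1-D array-valued shift into an int or a tuple of ints.

    Hypothetical helper for a backend whose roll only accepts Python ints.
    Assumes array inputs expose .tolist(); this needs concrete values, so it
    cannot handle a traced shift under jit.
    """
    if isinstance(shift, (int, tuple)):
        return shift  # already a spec-compliant shift
    values = shift.tolist()
    if isinstance(values, list):  # 1-D array -> tuple of ints
        if any(isinstance(v, list) for v in values):
            raise ValueError("shift must be zero- or one-dimensional")
        return tuple(int(v) for v in values)
    return int(values)  # 0-D array -> int


# array.array also has .tolist(), so it stands in for a 1-D array here.
print(normalize_shift(array("i", [1])))      # (1,)
print(normalize_shift(array("i", [1, -2])))  # (1, -2)
```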

@lucascolley
Member

Thanks @amacati. I propose that we go with option 2 and transfer this issue to the array-api-strict repo, since int | tuple[int, ...] appears to be the standard argument type across the manipulation functions (https://data-apis.org/array-api/draft/API_specification/manipulation_functions.html). Would you agree, @kgryte?

@kgryte
Contributor

kgryte commented Apr 2, 2025

@lucascolley It depends. According to each library's documentation, the accepted types for shift are:

  • NumPy: int or tuple of ints
  • CuPy: int or tuple of ints
  • PyTorch: int or tuple of ints
  • JAX: ArrayLike | Sequence[int]
  • Dask: int or tuple of ints
  • MLX: int or tuple of ints
  • ndonnx: int or tuple of ints

Based on the above, @amacati, it looks like only JAX officially supports arrays. We would probably need broader consensus before updating the standard, and someone would need to champion this change in the respective array libraries.

In the meantime, array-api-strict should be more strict.
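Until the standard changes, code that needs a dynamic, array-valued shift could sidestep roll: rolling a 1-D array is a gather whose index for output position i is (i - shift) mod n. The pure-Python sketch below, with the hypothetical name `roll_1d`, only illustrates that index formula on lists. In array API terms the same gather would correspond to something like `xp.take(x, (xp.arange(n) - shift) % n, axis=0)`, which uses only standard functions and keeps the shift as an array; that is an assumption about a possible workaround, not something proposed in this thread or a description of any library's implementation.

```python
def roll_1d(x, shift):
    """Reference semantics of roll on a 1-D sequence: out[i] = x[(i - shift) % n].

    Hypothetical illustration only; Python's % keeps the result in [0, n),
    so negative and oversized shifts wrap around correctly.
    """
    n = len(x)
    if n == 0:
        return []
    return [x[(i - shift) % n] for i in range(n)]


print(roll_1d([1, 2, 3], 1))   # [3, 1, 2]
print(roll_1d([1, 2, 3], -1))  # [2, 3, 1]
```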

@kgryte kgryte changed the title Clarification on xp.roll's signature RFC: add array support for the shift kwarg in roll Apr 2, 2025
@kgryte kgryte added the RFC Request for comments. Feature requests and proposed changes. label Apr 2, 2025