RFC: add array support for the shift kwarg in roll #914
Labels: RFC
The Array API standard states that xp.roll accepts two types for the shift argument: ints and tuple[int, ...]. However, it is currently possible to run the following piece of code:

It is nice that this works, because it allows jitting the function in JAX while using JAX arrays as dynamic shifts. However, the spec currently does not guarantee that it works, so I would expect the code to fail. Indeed, relying on this behaviour is dangerous: executing the same snippet with torch yields

Using an explicit integer value instead of a tensor as the argument does work.
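To make the distinction concrete, here is a small sketch that uses NumPy directly as the array namespace. This is an assumption for illustration only and does not reproduce the original snippet. The int and tuple forms are what the standard specifies. The 0-D array form happens to be accepted by NumPy, but the spec does not guarantee it, and that is exactly the behaviour this RFC asks to standardise.

```python
import numpy as np

xp = np  # assumption: NumPy stands in for a generic array namespace

x = xp.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

# Forms covered by the standard: an int, or a tuple of ints paired with a tuple of axes.
by_int = xp.roll(x, 1, axis=1)                 # → [[2, 0, 1], [5, 3, 4]]
by_tuple = xp.roll(x, (1, 1), axis=(0, 1))     # → [[5, 3, 4], [2, 0, 1]]

# Form NOT covered by the standard: a 0-D array as the shift.
# NumPy accepts it, but other libraries (torch, per this issue) may reject it.
by_array = xp.roll(x, xp.asarray(1), axis=1)   # same result as by_int under NumPy
```

Portable code written against the current standard should stick to the first two forms.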
Remedies
There are two ways to fix this:

1. array_api_strict needs to explicitly check the type of xp.roll's kwargs, so that shifts outside the specified types (such as arrays) are rejected.
2. If the spec is updated to include arrays, array_api_compat needs to update its torch wrapper for roll such that the example above does not lead to any errors.

I would be strongly in favour of accepting arrays (both zero- and one-dimensional), because it allows us to do more with xp.roll within jit-compiled functions.
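The two remedies pull in opposite directions, and the sketch below illustrates each under stated assumptions. check_roll_shift and normalize_shift are hypothetical names, not functions from array_api_strict or array_api_compat. The first mimics a strict-mode check that allows only int or tuple[int, ...]. The second shows one way a compat wrapper could turn a 0-D or 1-D integer array into the int/tuple form that an eager backend such as torch.roll expects. It uses only indexing and int(), both of which the standard defines for arrays.

```python
import numpy as np  # used only for the usage lines at the bottom


def check_roll_shift(shift):
    """Hypothetical strict check: accept only an int or a tuple of ints."""
    def is_int(v):
        # bool subclasses int in Python; rejecting it is a design choice of this sketch.
        return isinstance(v, int) and not isinstance(v, bool)

    if is_int(shift):
        return
    if isinstance(shift, tuple) and all(is_int(s) for s in shift):
        return
    raise TypeError(
        f"shift must be an int or a tuple of ints, got {type(shift).__name__}"
    )


def normalize_shift(shift):
    """Hypothetical compat helper: convert a 0-D/1-D integer array shift to
    int / tuple[int, ...]. int() needs a concrete value, so this cannot be
    applied to traced JAX arrays inside jax.jit."""
    ndim = getattr(shift, "ndim", None)
    if ndim is None:
        return shift  # already an int or a tuple; pass through unchanged
    if ndim == 0:
        return int(shift)
    if ndim == 1:
        return tuple(int(shift[i]) for i in range(shift.shape[0]))
    raise ValueError("shift must be a 0-D or 1-D array")


# Usage
check_roll_shift((1, 2))                  # passes silently
zero_d = normalize_shift(np.asarray(2))   # → 2
one_d = normalize_shift(np.asarray([1, -1]))  # → (1, -1)
```

Note that a conversion like normalize_shift only makes the eager torch call succeed. Keeping shifts as traced arrays under jit requires the backend itself to accept array shifts, which is why standardising array support matters for the JAX use case.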