Broadcasting tests #1
See my answer on gh-2.
I think broadcasting is a behaviour that needs testing for every function that has broadcastable inputs, so it's a big job. For now I'd start with one or a handful of functions, and see if you can write a broadcast test with shape logic that can be reused.
Yes, this is a pain; we really should have such a function in NumPy. I think we need it here. The algorithm isn't actually all that hard, and I suspect there's a pure-Python implementation floating around somewhere.
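For reference, here is a minimal pure-Python sketch of shape broadcasting (names and error type are hypothetical; this right-aligns the shapes and resolves each dimension pair, rather than following the spec's loop literally):

```python
def broadcast_shapes(*shapes):
    """Return the broadcast result shape, or raise ValueError."""
    ndim = max((len(s) for s in shapes), default=0)
    # Left-pad every shape with 1s so they all have the same rank.
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    result = []
    for dims in zip(*padded):
        # All non-1 sizes in a column must agree; 1 broadcasts to anything.
        non_one = {d for d in dims if d != 1}
        if len(non_one) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        result.append(non_one.pop() if non_one else 1)
    return tuple(result)
```

For example, `broadcast_shapes((8, 1, 6, 1), (7, 1, 5))` gives `(8, 7, 6, 5)`, and incompatible shapes like `(3,)` and `(4,)` raise.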
Ah. I had to read that section again, you're right - it's missing a statement on how values in the array being broadcast are "virtually duplicated" (not sure that's great wording).
Indeed. Strides are explicitly out of scope - there are arrays that do not have a strided implementation (e.g. JAX, Dask). That's an implementation detail, and the spec attempts to be careful not to specify anything that puts unnecessary constraints on any library - JIT compilers, lazy evaluation, distributed arrays, GPU arrays, etc. all should be possible.
Re: what an array looks like. That is covered, perhaps too obliquely, here:
The key phrase being "virtually repeated". May be worthwhile for me to add an example to the spec doc to show what the broadcasting algorithm means in practice.
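For instance (using NumPy purely to illustrate the semantics, not as part of the spec), a size-1 dimension behaves as if its values were repeated along the broadcast axis:

```python
import numpy as np

a = np.array([[1], [2], [3]])   # shape (3, 1)
b = np.array([10, 20])          # shape (2,), treated as (1, 2)
# a's single column is "virtually repeated" across 2 columns, and
# b's single row across 3 rows, giving a (3, 2) result.
print(a + b)
# [[11 21]
#  [12 22]
#  [13 23]]
```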
I was actually going to open an issue on NumPy about this some time ago, but never got around to it. I've done so here numpy/numpy#17217.
I believe there is a bug in the broadcasting pseudocode. If I implement it directly, I get IndexErrors from the test cases. I believe the correct code should index the input shapes with `N1 - N + i` and `N2 - N + i` rather than `i` (also note that the spec says …).

Here is the function as defined in the spec:

```python
class BroadcastError(Exception):
    pass


def broadcast_shapes(shape1, shape2):
    """
    Broadcast shapes `shape1` and `shape2`.

    The code in this function should follow the pseudocode in the spec as
    closely as possible.
    """
    N1 = len(shape1)
    N2 = len(shape2)
    N = max(N1, N2)
    shape = [None]*N
    i = N - 1
    while i >= 0:
        if N1 - N + i >= 0:
            d1 = shape1[i]
        else:
            d1 = 1
        if N2 - N + i >= 0:
            d2 = shape2[i]
        else:
            d2 = 1
        if d1 == 1:
            shape[i] = d2
        elif d2 == 1:
            shape[i] = d1
        elif d1 == d2:
            shape[i] = d1
        else:
            raise BroadcastError
        i = i - 1
    return tuple(shape)
```

And here is what I believe is the correct version:

```python
class BroadcastError(Exception):
    pass


def broadcast_shapes(shape1, shape2):
    """
    Broadcast shapes `shape1` and `shape2`.

    The code in this function should follow the pseudocode in the spec as
    closely as possible.
    """
    N1 = len(shape1)
    N2 = len(shape2)
    N = max(N1, N2)
    shape = [None]*N
    i = N - 1
    while i >= 0:
        if N1 - N + i >= 0:
            d1 = shape1[N1 - N + i]  # This line is different
        else:
            d1 = 1
        if N2 - N + i >= 0:
            d2 = shape2[N2 - N + i]  # This line is different
        else:
            d2 = 1
        if d1 == 1:
            shape[i] = d2
        elif d2 == 1:
            shape[i] = d1
        elif d1 == d2:
            shape[i] = d1
        else:
            raise BroadcastError
        i = i - 1
    return tuple(shape)
```

And here are the test cases from the spec:

```python
from pytest import raises


def test_broadcast_shapes_explicit_spec():
    """
    Explicit broadcast shapes examples from the spec.
    """
    shape1 = (8, 1, 6, 1)
    shape2 = (7, 1, 5)
    result = (8, 7, 6, 5)
    assert broadcast_shapes(shape1, shape2) == result

    shape1 = (5, 4)
    shape2 = (1,)
    result = (5, 4)
    assert broadcast_shapes(shape1, shape2) == result

    shape1 = (5, 4)
    shape2 = (4,)
    result = (5, 4)
    assert broadcast_shapes(shape1, shape2) == result

    shape1 = (15, 3, 5)
    shape2 = (15, 1, 5)
    result = (15, 3, 5)
    assert broadcast_shapes(shape1, shape2) == result

    shape1 = (15, 3, 5)
    shape2 = (3, 5)
    result = (15, 3, 5)
    assert broadcast_shapes(shape1, shape2) == result

    shape1 = (15, 3, 5)
    shape2 = (3, 1)
    result = (15, 3, 5)
    assert broadcast_shapes(shape1, shape2) == result

    shape1 = (3,)
    shape2 = (4,)
    raises(BroadcastError, lambda: broadcast_shapes(shape1, shape2))  # dimension does not match

    shape1 = (2, 1)
    shape2 = (8, 4, 3)
    raises(BroadcastError, lambda: broadcast_shapes(shape1, shape2))  # second dimension does not match

    shape1 = (15, 3, 5)
    shape2 = (15, 3)
    raises(BroadcastError, lambda: broadcast_shapes(shape1, shape2))  # singleton dimensions can only be prepended, not appended
```
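The off-by-one can be reproduced in a few lines: when `N1 < N`, the loop index `i` indexes the padded result shape, not `shape1`, so the spec's `shape1[i]` walks past the end of the shorter shape even though the `N1 - N + i >= 0` guard correctly detects that a dimension exists there (shapes below taken from the spec's first example):

```python
shape1, shape2 = (7, 1, 5), (8, 1, 6, 1)
N1, N2 = len(shape1), len(shape2)   # 3, 4
N = max(N1, N2)                     # 4
i = N - 1                           # 3, the first (rightmost) iteration

# The guard correctly says shape1 has a dimension at this position...
assert N1 - N + i >= 0
# ...but indexing with i goes out of bounds, since len(shape1) == 3:
try:
    shape1[i]
except IndexError:
    print("IndexError: i indexes the padded result, not shape1")
# The corrected index maps back into the shorter shape:
assert shape1[N1 - N + i] == 5
```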
@asmeurer Thanks for catching this. Will submit a patch to the spec shortly.
@asmeurer Patch submitted: data-apis/array-api#30.
We have tests for broadcasting, but they only cover the elementwise functions. We need to extend them to operators (including the in-place operator logic), as well as other functions. c.f. #24
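A rough sketch of what such an operator test could check (using NumPy as a stand-in array library; the exact helpers the suite would use are not specified here): binary operators should broadcast like their functional counterparts, while in-place operators must not broadcast the left operand to a larger shape.

```python
import numpy as np

a = np.ones((3, 1))
b = np.ones((1, 4))

# Binary operators broadcast exactly like the corresponding functions.
assert (a + b).shape == np.add(a, b).shape == (3, 4)

# In-place operators may only broadcast the *right* operand: the result
# must fit in the left operand's existing shape.
c = np.ones((3, 4))
c += a                      # (3, 1) broadcasts into (3, 4): allowed
assert c.shape == (3, 4)
try:
    a += b                  # result shape (3, 4) != a.shape (3, 1)
except ValueError:
    pass                    # NumPy rejects in-place shape growth
```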
Here is the spec for broadcasting: https://github.com/data-apis/array-api/blob/master/spec/API_specification/broadcasting.md.

Here are some questions about it:

- How do I create input arrays for the test? The array object document is empty: https://github.com/data-apis/array-api/blob/master/spec/API_specification/array_object.md. Do we have at least enough of an idea of what that will look like so I can create tests?
- What is the best way to test broadcasting? The simplest would be to use a function like `numpy.broadcast_arrays` or `numpy.broadcast_to`, but these aren't listed in the spec. And even NumPy doesn't have a function that directly implements the shape broadcasting algorithm; it can only be applied to explicit arrays. The spec says broadcasting should apply to all elementwise operations. What is a good elementwise operation that we can use to test only the broadcasting semantics? Or should we make sure to test all of them?
- The spec doesn't actually specify what the resulting broadcast array should look like, only what its shape is. Is this intentional? Should we test this? If not, it means we don't actually test the result of a broadcast operation, only that the shape/errors are correct.
As I understand it, "potentially enable more memory-efficient element-wise operations" means that broadcasting does not necessarily need to be done in a memory-efficient way, i.e., libraries are free to copy values along broadcast dimensions rather than using something like a stride trick.
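As a NumPy-specific illustration of the stride trick (an implementation detail, not something the spec requires), `numpy.broadcast_to` creates the repeated view with a stride of 0 along the broadcast axis, so no values are actually copied:

```python
import numpy as np

a = np.arange(4)                # shape (4,): four elements of storage
b = np.broadcast_to(a, (3, 4))  # shape (3, 4), still only four elements

# The repeated axis has stride 0: each "row" revisits the same memory.
assert b.strides[0] == 0
assert b.base is not None       # b is a view, not a copy
# A library that instead materialized the repeats would use 3x the
# memory; the spec's wording leaves either implementation open.
```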