Viewing file: test_stride_tricks.py (6.79 KB) -rw-r--r-- Select action/file-type: (+) | (+) | (+) | Code (+) | Session (+) | (+) | SDB (+) | (+) | (+) | (+) | (+) | (+) |
import numpy as np from numpy.testing import * from numpy.lib.stride_tricks import broadcast_arrays
def assert_shapes_correct(input_shapes, expected_shape):
    """Assert that arrays of the given shapes all broadcast to one shape.

    A zero-filled array is built for every entry of `input_shapes`, the
    arrays are passed together to `broadcast_arrays`, and every returned
    array must have exactly `expected_shape`.
    """
    arrays = [np.zeros(shape) for shape in input_shapes]
    actual_shapes = [result.shape for result in broadcast_arrays(*arrays)]
    # Every output must share the single common shape.
    assert actual_shapes == [expected_shape] * len(arrays)
def assert_incompatible_shapes_raise(input_shapes):
    """Assert that broadcasting arrays of these shapes raises ValueError.

    A zero-filled array is built for every entry of `input_shapes`; passing
    them together to `broadcast_arrays` is expected to fail.
    """
    arrays = [np.zeros(shape) for shape in input_shapes]
    with assert_raises(ValueError):
        broadcast_arrays(*arrays)
def assert_same_as_ufunc(shape0, shape1, transposed=False, flipped=False):
    """Check that broadcast_arrays agrees with ufunc broadcasting.

    Two arrays are built: zeros of shape `shape0` and a reshaped `arange` of
    shape `shape1`.  They are added with the ``+`` ufunc and also passed to
    `broadcast_arrays`; the broadcast view of the second array must equal the
    sum, and both broadcast outputs must have the shape of the sum.

    Parameters
    ----------
    shape0, shape1 : tuple of int
        Shapes of the two arrays to broadcast against each other.
    transposed : bool, optional
        If true, both arrays are transposed before broadcasting.
    flipped : bool, optional
        If true, both arrays are reversed along their first axis (giving a
        negative stride) before broadcasting.  Callers must not request this
        for rank-0 shapes, which cannot be sliced.
    """
    x0 = np.zeros(shape0, dtype=int)
    # Note that multiply.reduce's identity element is 1.0, so when
    # shape1 == (), this gives the desired n == 1.
    n = int(np.multiply.reduce(shape1))
    x1 = np.arange(n).reshape(shape1)
    if transposed:
        x0 = x0.T
        x1 = x1.T
    if flipped:
        x0 = x0[::-1]
        x1 = x1[::-1]
    # Use the add ufunc to do the broadcasting.  Since we're adding 0s to x1,
    # the result should be exactly the same as the broadcasted view of x1.
    y = x0 + x1
    b0, b1 = broadcast_arrays(x0, x1)
    # Both outputs must take on the common broadcast shape; previously the
    # first output was unpacked but never examined.
    assert b0.shape == y.shape
    assert b1.shape == y.shape
    assert_array_equal(y, b1)
def test_same():
    """Broadcasting two arrays of identical shape leaves both unchanged."""
    first, second = np.arange(10), np.arange(10)
    broadcast_first, broadcast_second = broadcast_arrays(first, second)
    assert_array_equal(first, broadcast_first)
    assert_array_equal(second, broadcast_second)
def test_one_off():
    """A (1, 3) row and a (3, 1) column broadcast to matching (3, 3) arrays."""
    row = np.array([[1, 2, 3]])
    column = np.array([[1], [2], [3]])
    broadcast_row, broadcast_column = broadcast_arrays(row, column)

    # The row repeated down the rows; the column result is its transpose.
    expected_row = np.array([[1, 2, 3]] * 3)
    expected_column = expected_row.T

    assert_array_equal(expected_row, broadcast_row)
    assert_array_equal(expected_column, broadcast_column)
def test_same_input_shapes():
    """Check that the final shape is just the input shape.

    Every shape is broadcast against itself with one, two and three inputs.
    The checks are called directly instead of being yielded, so they run
    under test runners that do not support generator-style tests.
    """
    data = [
        (),
        (1,),
        (3,),
        (0, 1),
        (0, 3),
        (1, 0),
        (3, 0),
        (1, 3),
        (3, 1),
        (3, 3),
    ]
    for shape in data:
        # Single, double and triple input, in that order.
        for count in (1, 2, 3):
            assert_shapes_correct([shape] * count, shape)
def test_two_compatible_by_ones_input_shapes():
    """Check two same-length shapes that are compatible through 1s.

    Each entry pairs two input shapes of equal length (some dimensions being
    1) with the shape they must broadcast to.  The checks are called directly
    instead of being yielded, so they run under test runners that do not
    support generator-style tests.
    """
    data = [
        [[(1,), (3,)], (3,)],
        [[(1, 3), (3, 3)], (3, 3)],
        [[(3, 1), (3, 3)], (3, 3)],
        [[(1, 3), (3, 1)], (3, 3)],
        [[(1, 1), (3, 3)], (3, 3)],
        [[(1, 1), (1, 3)], (1, 3)],
        [[(1, 1), (3, 1)], (3, 1)],
        [[(1, 0), (0, 0)], (0, 0)],
        [[(0, 1), (0, 0)], (0, 0)],
        [[(1, 0), (0, 1)], (0, 0)],
        [[(1, 1), (0, 0)], (0, 0)],
        [[(1, 1), (1, 0)], (1, 0)],
        [[(1, 1), (0, 1)], (0, 1)],
    ]
    for input_shapes, expected_shape in data:
        assert_shapes_correct(input_shapes, expected_shape)
        # Reverse the input shapes since broadcasting should be symmetric.
        assert_shapes_correct(input_shapes[::-1], expected_shape)
def test_two_compatible_by_prepending_ones_input_shapes():
    """Check two shapes of different lengths broadcast to the correct shape.

    Each entry pairs two input shapes of different rank with the shape they
    must broadcast to.  The checks are called directly instead of being
    yielded, so they run under test runners that do not support
    generator-style tests.
    """
    data = [
        [[(), (3,)], (3,)],
        [[(3,), (3, 3)], (3, 3)],
        [[(3,), (3, 1)], (3, 3)],
        [[(1,), (3, 3)], (3, 3)],
        [[(), (3, 3)], (3, 3)],
        [[(1, 1), (3,)], (1, 3)],
        [[(1,), (3, 1)], (3, 1)],
        [[(1,), (1, 3)], (1, 3)],
        [[(), (1, 3)], (1, 3)],
        [[(), (3, 1)], (3, 1)],
        [[(), (0,)], (0,)],
        [[(0,), (0, 0)], (0, 0)],
        [[(0,), (0, 1)], (0, 0)],
        [[(1,), (0, 0)], (0, 0)],
        [[(), (0, 0)], (0, 0)],
        [[(1, 1), (0,)], (1, 0)],
        [[(1,), (0, 1)], (0, 1)],
        [[(1,), (1, 0)], (1, 0)],
        [[(), (1, 0)], (1, 0)],
        [[(), (0, 1)], (0, 1)],
    ]
    for input_shapes, expected_shape in data:
        assert_shapes_correct(input_shapes, expected_shape)
        # Reverse the input shapes since broadcasting should be symmetric.
        assert_shapes_correct(input_shapes[::-1], expected_shape)
def test_incompatible_shapes_raise_valueerror():
    """Check that a ValueError is raised for incompatible shapes.

    The checks are called directly instead of being yielded, so they run
    under test runners that do not support generator-style tests.
    """
    data = [
        [(3,), (4,)],
        [(2, 3), (2,)],
        [(3,), (3,), (4,)],
        [(1, 3, 4), (2, 3, 3)],
    ]
    for input_shapes in data:
        assert_incompatible_shapes_raise(input_shapes)
        # Reverse the input shapes since broadcasting should be symmetric.
        assert_incompatible_shapes_raise(input_shapes[::-1])
def test_same_as_ufunc():
    """Check that the data layout is the same as if a ufunc did the operation.

    Every pair of shapes is checked as given, with the two shapes swapped,
    transposed, and (when neither shape is rank-0) flipped to exercise
    negative strides.  The checks are called directly instead of being
    yielded, so they run under test runners that do not support
    generator-style tests.
    """
    data = [
        [[(1,), (3,)], (3,)],
        [[(1, 3), (3, 3)], (3, 3)],
        [[(3, 1), (3, 3)], (3, 3)],
        [[(1, 3), (3, 1)], (3, 3)],
        [[(1, 1), (3, 3)], (3, 3)],
        [[(1, 1), (1, 3)], (1, 3)],
        [[(1, 1), (3, 1)], (3, 1)],
        [[(1, 0), (0, 0)], (0, 0)],
        [[(0, 1), (0, 0)], (0, 0)],
        [[(1, 0), (0, 1)], (0, 0)],
        [[(1, 1), (0, 0)], (0, 0)],
        [[(1, 1), (1, 0)], (1, 0)],
        [[(1, 1), (0, 1)], (0, 1)],
        [[(), (3,)], (3,)],
        [[(3,), (3, 3)], (3, 3)],
        [[(3,), (3, 1)], (3, 3)],
        [[(1,), (3, 3)], (3, 3)],
        [[(), (3, 3)], (3, 3)],
        [[(1, 1), (3,)], (1, 3)],
        [[(1,), (3, 1)], (3, 1)],
        [[(1,), (1, 3)], (1, 3)],
        [[(), (1, 3)], (1, 3)],
        [[(), (3, 1)], (3, 1)],
        [[(), (0,)], (0,)],
        [[(0,), (0, 0)], (0, 0)],
        [[(0,), (0, 1)], (0, 0)],
        [[(1,), (0, 0)], (0, 0)],
        [[(), (0, 0)], (0, 0)],
        [[(1, 1), (0,)], (1, 0)],
        [[(1,), (0, 1)], (0, 1)],
        [[(1,), (1, 0)], (1, 0)],
        [[(), (1, 0)], (1, 0)],
        [[(), (0, 1)], (0, 1)],
    ]
    # The expected shape stored with each entry is not needed here; the
    # helper compares against the ufunc result instead.
    for input_shapes, _ in data:
        shape0, shape1 = input_shapes
        assert_same_as_ufunc(shape0, shape1)
        # Reverse the input shapes since broadcasting should be symmetric.
        assert_same_as_ufunc(shape1, shape0)
        # Try them transposed, too.
        assert_same_as_ufunc(shape0, shape1, True)
        # ... and flipped for non-rank-0 inputs in order to test negative
        # strides.
        if () not in input_shapes:
            assert_same_as_ufunc(shape0, shape1, False, True)
            assert_same_as_ufunc(shape0, shape1, True, True)
# When this file is executed directly as a script, call run_module_suite()
# (presumably provided by the numpy.testing star import -- confirm).
if __name__ == "__main__":
    run_module_suite()
|