I have an array like this:
import numpy as np

a = np.array(
    [[
        [
            [0.        , 0.        ],
            [0.        , 0.        ],
            [0.        , 0.        ]
        ],
        [
            [0.07939843, 0.13330124],
            [0.20078699, 0.17482429],
            [0.49948007, 0.52375553]
        ]
    ]]
)
print(a.shape)  # (1, 2, 3, 2)
Now I need to compare the two values along the last axis and replace the maximum with 1 and the minimum with 0. When both values are equal, as in the all-zero rows, the first one should become 1. So I need a result like this:
array([[[[1., 0.],
          [1., 0.],
          [1., 0.]],
         [[0., 1.],
          [1., 0.],
          [0., 1.]]]])
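Put differently, the target is a one-hot encoding of the position of the maximum along the last axis. Here is a minimal sketch of that formulation, assuming only NumPy. `argmax` returns the first index when values tie, which yields `[1., 0.]` for the all-zero rows, and fancy-indexing an identity matrix turns each index into a one-hot row. The name `result` is illustrative.

```python
import numpy as np

a = np.array([[[[0., 0.], [0., 0.], [0., 0.]],
               [[0.07939843, 0.13330124],
                [0.20078699, 0.17482429],
                [0.49948007, 0.52375553]]]])

# Position of the largest value along the last axis, shape (1, 2, 3).
# On ties argmax returns the first index.
idx = a.argmax(axis=-1)

# Row idx of the identity matrix is the one-hot vector for idx,
# so indexing with the whole idx array gives shape (1, 2, 3, 2).
result = np.eye(a.shape[-1], dtype=a.dtype)[idx]
print(result)
```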
Well, I have tried a nested loop like this:
for i in range(a.shape[0]):
    for j in range(a.shape[1]):
        for k in range(a.shape[2]):
            # ">=" sends ties (e.g. the all-zero rows) to the first column,
            # as in the expected result above
            if a[i, j, k, 0] >= a[i, j, k, 1]:
                a[i, j, k, 0] = 1
                a[i, j, k, 1] = 0
            else:
                a[i, j, k, 0] = 0
                a[i, j, k, 1] = 1
However, this is slow for large arrays, so I need a faster way, ideally a vectorized built-in NumPy function.
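A vectorized sketch of what the loop does, assuming the last axis always has exactly two entries: a single boolean comparison over all `(i, j, k)` positions replaces the three Python loops. `np.put_along_axis` (available since NumPy 1.15) is shown as a variant that also works for any last-axis length without building an identity matrix. `first_is_max`, `result` and `out` are illustrative names.

```python
import numpy as np

a = np.array([[[[0., 0.], [0., 0.], [0., 0.]],
               [[0.07939843, 0.13330124],
                [0.20078699, 0.17482429],
                [0.49948007, 0.52375553]]]])

# Two-column case: True where column 0 is the maximum (ties go to column 0).
first_is_max = a[..., 0] >= a[..., 1]
# Column 0 gets first_is_max, column 1 its complement; cast bools to a's dtype.
result = np.stack([first_is_max, ~first_is_max], axis=-1).astype(a.dtype)

# General case: start from zeros and write a 1 at each argmax position.
# put_along_axis needs indices with the same ndim as the target, hence [..., None].
out = np.zeros_like(a)
np.put_along_axis(out, a.argmax(axis=-1)[..., None], 1, axis=-1)
```

Both versions return a new array; to overwrite `a` in place as the loop does, assign `a[...] = result`.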