To run a Dask cluster on Kubernetes capable of GPU compute, you need the following:
- Kubernetes nodes need GPUs and drivers. This can be set up with the NVIDIA k8s device plugin (see the snippet after this list).
 
- Scheduler and worker pods will need a Docker image with NVIDIA tools installed. As you suggest, the RAPIDS images are good for this.
 
- The pod container spec will need GPU resources such as `resources.limits.nvidia.com/gpu: 1`.
- The Dask workers need to be started with the `dask-cuda-worker` command from the `dask_cuda` package (which is included in the RAPIDS images).
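For the device plugin, here is a minimal sketch of installing it and verifying that nodes advertise GPUs (the version tag is an assumption; check the NVIDIA/k8s-device-plugin repo for the current release):

```console
$ kubectl create -f https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.14.0/nvidia-device-plugin.yml
$ kubectl describe nodes | grep nvidia.com/gpu  # nodes should now report nvidia.com/gpu capacity
```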
Note: For Dask Gateway, your container image also needs the `dask-gateway` package to be installed. We can configure this to be installed at runtime, but it's probably best to create a custom image with this package installed.
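For example, a custom image could be as simple as the following sketch (the conda environment name `rapids` is an assumption based on the RAPIDS images; adjust it if yours differs):

```dockerfile
# Hypothetical custom image: extend the RAPIDS image and bake in dask-gateway
FROM rapidsai/rapidsai:cuda11.0-runtime-ubuntu18.04-py3.8
RUN conda run -n rapids pip install dask-gateway
```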
Therefore, here is a minimal Dask Gateway config that will get you a GPU cluster.
```yaml
# config.yaml
gateway:
  backend:
    image:
      name: rapidsai/rapidsai
      tag: cuda11.0-runtime-ubuntu18.04-py3.8  # Be sure to match your k8s CUDA version and your users' Python version
    worker:
      extraContainerConfig:
        env:
          - name: EXTRA_PIP_PACKAGES
            value: "dask-gateway"
        resources:
          limits:
            nvidia.com/gpu: 1  # This could be >1; you will get one worker process in the pod per GPU
    scheduler:
      extraContainerConfig:
        env:
          - name: EXTRA_PIP_PACKAGES
            value: "dask-gateway"
        resources:
          limits:
            nvidia.com/gpu: 1  # The scheduler requires a GPU in case of accidental deserialisation
  extraConfig:
    cudaworker: |
      c.ClusterConfig.worker_cmd = "dask-cuda-worker"
```
We can test that things work by launching Dask Gateway, creating a Dask cluster and running some GPU-specific work. Here is an example where we get the NVIDIA driver version from each worker.
```console
$ helm install dgwtest daskgateway/dask-gateway -f config.yaml
```
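It can take a minute for everything to start, so before connecting it's worth checking that the gateway pods are running and that the proxy service the client connects to is up:

```console
$ kubectl get pods  # wait for the gateway pods to reach Running
$ kubectl get svc   # the client below connects to the dask-gateway-service address
```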
```python
In [1]: from dask_gateway import Gateway

In [2]: gateway = Gateway("http://dask-gateway-service")

In [3]: cluster = gateway.new_cluster()

In [4]: cluster.scale(1)

In [5]: from dask.distributed import Client

In [6]: client = Client(cluster)

In [7]: def get_nvidia_driver_version():
   ...:     import pynvml
   ...:     pynvml.nvmlInit()  # initialise NVML; safe even if the worker has already done so
   ...:     return pynvml.nvmlSystemGetDriverVersion()
   ...:

In [9]: client.run(get_nvidia_driver_version)
Out[9]: {'tls://10.42.0.225:44899': b'450.80.02'}
```
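As a further sanity check, we can run an actual computation on the GPUs. Here is a sketch which assumes `cupy` is available (it ships with the RAPIDS images): we build a Dask array backed by `cupy` chunks, so the sum executes on the workers' GPUs.

```python
In [10]: import cupy

In [11]: import dask.array as da

In [12]: rs = da.random.RandomState(RandomState=cupy.random.RandomState)  # cupy-backed chunks

In [13]: x = rs.random((10000, 10000), chunks=(2500, 2500))

In [14]: x.sum().compute()  # executes on the workers' GPUs; expect roughly 5e7
```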