Update Distributed Tensorflow Training Page
This page should be updated to include a more generic example of distributed training with TensorFlow:
- There is a `SlurmClusterResolver` which should be able to configure the distribution strategy correctly from the Slurm environment (this still has to be tested!); a sketch of this approach follows after this list.
- Below is a script which also sets up a correct `TF_CONFIG` environment variable. It is, however, not yet properly tested and has a dependency on `mpi4py`. It should not be copied directly, but it can serve as the basis of a more generic solution which has yet to be developed.
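A minimal, untested sketch of how the `SlurmClusterResolver` approach could look inside a Slurm job (the small Keras model and the random data are placeholders, not part of any existing script):

```python
import numpy as np
import tensorflow as tf

# SlurmClusterResolver derives the cluster spec from the SLURM_* environment
# variables exported inside the job (node list, task id, number of tasks, ...).
resolver = tf.distribute.cluster_resolver.SlurmClusterResolver()
strategy = tf.distribute.MultiWorkerMirroredStrategy(cluster_resolver=resolver)

with strategy.scope():
    # placeholder model; replace with the actual network
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(32,)),
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")

# placeholder data, only to make the sketch self-contained
x = np.random.rand(1024, 32).astype("float32")
y = np.random.rand(1024, 1).astype("float32")
model.fit(x, y, epochs=2, batch_size=64)
```

If this works as expected, no manual `TF_CONFIG` handling would be needed, since the resolver builds the cluster specification itself.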
Script
```python
import os
import json
import logging

import hostlist
from mpi4py import MPI


def main():
    logging.basicConfig(level=logging.DEBUG)
    # one MPI rank corresponds to one TensorFlow worker
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()
    set_environment(rank=rank, size=size, color=rank, comm_dist_1=comm)


def clean_environment():
    # hide all GPUs from TensorFlow and drop any stale TF_CONFIG
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    os.environ.pop("TF_CONFIG", None)


def set_environment(rank: int, size: int, color: int, comm_dist_1):
    # first clean environment
    clean_environment()
    # smallest port we want to use
    port_base = 33000
    # calculate offset
    my_port = port_base + rank
    # get hostnames from slurm
    hostnames = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"])
    # get local node id to determine the node this process is running on
    local_nodeID = int(os.environ["SLURM_NODEID"])
    # assemble local string
    my_config = f"{hostnames[local_nodeID]}:{my_port}"
    logging.debug(f"[{rank:03d}]:{my_config}")
    # comm_dist_1 is a communicator where all ranks are included
    global_conf = comm_dist_1.allgather(my_config)
    logging.debug(global_conf)
    assert global_conf[color] == my_config
    # every rank lists all workers and marks its own position in the list
    tf_config = {
        "cluster": {
            "worker": global_conf,
        },
        "task": {"index": color, "type": "worker"},
    }
    os.environ["TF_CONFIG"] = json.dumps(tf_config)
    logging.debug(os.environ["TF_CONFIG"])


if __name__ == "__main__":
    main()
```
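A minimal, untested sketch of how a training script could use the helper above; it assumes `set_environment()` is defined in (or imported into) the same file. `MultiWorkerMirroredStrategy` then reads the `TF_CONFIG` variable that the helper just exported:

```python
from mpi4py import MPI
import tensorflow as tf

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# export TF_CONFIG for this rank before creating the strategy
set_environment(rank=rank, size=comm.Get_size(), color=rank, comm_dist_1=comm)

# with TF_CONFIG set, the strategy knows all workers and this rank's index
strategy = tf.distribute.MultiWorkerMirroredStrategy()

with strategy.scope():
    # model construction and model.fit(...) go here, exactly as in the
    # resolver-based sketch above
    ...
```

Such a script would presumably be launched with one MPI task per worker (e.g. via `srun`), since every rank claims its own port and its own slot in the worker list.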