Estimator: A Tutorial

Posted by chunyang on September 3, 2020

An introduction to TensorFlow Estimators.

Background

In the paper TensorFlow Estimators: Managing Simplicity vs. Flexibility in High-Level Machine Learning Frameworks, the TensorFlow team describes a layer of abstraction built for users on top of TensorFlow. It mainly shields the user from:

  • Session creation
  • distributed-training logic:
    • including cluster setup and the construction of the corresponding Servers

This article explains in detail how Estimator actually works.

Parameter-server-based distributed training without Estimator

This article focuses on data-parallel distributed training on top of the Parameter Server architecture. Traditional distributed training roughly follows this logic:

  • For a parameter server:
import tensorflow as tf

task_index = 0
ps_hosts = ["a:1001", "b:1011"]
worker_hosts = ["a:1002", "c:1003"]

cluster_def = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

# The PS server sees all ps and worker tasks
session_config = tf.ConfigProto(
    device_filters=["/job:ps", "/job:worker"],
)

server = tf.train.Server(
    cluster_def,
    job_name="ps",
    task_index=task_index,
    config=session_config,
)
server.join()
  • For a worker:
import tensorflow as tf

task_index = 0
ps_hosts = ["a:1001", "b:1011"]
worker_hosts = ["a:1002", "c:1003"]

cluster_def = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

# The worker sees all ps tasks and only its own worker task
session_config = tf.ConfigProto(
    device_filters=["/job:ps", "/job:worker/task:%s"%task_index],
)

server = tf.train.Server(
    cluster_def,
    job_name="worker",
    task_index=task_index,
    config=session_config,
)

# the server instance will be passed to other functions,
# such as `tf.train.MonitoredTrainingSession`

On top of this boilerplate, users still have to build the data pipeline, build the model, train it (forward and backward passes), evaluate it, and finally export it. An algorithm engineer's valuable time should go into modeling, yet a great deal of it is spent on repetitive plumbing. Estimator exists to hide more of these low-level details from the user and to speed up algorithm development and iteration.

Estimator

Let's start with the high-level picture. Estimator exposes three main behaviors:

  • train
  • evaluate
  • predict

[Figure: high-level overview of Estimator]

However varied the higher-level APIs are, they all eventually funnel into these three methods, which correspond to three modes:

  • tf.estimator.ModeKeys.TRAIN
  • tf.estimator.ModeKeys.PREDICT
  • tf.estimator.ModeKeys.EVAL

Each mode places different requirements on the returned EstimatorSpec: TRAIN requires loss and train_op, EVAL requires loss, and PREDICT requires predictions.

Behavior control

The main control logic relies on hooks. EstimatorSpec exposes roughly four kinds of hooks:

  • training_chief_hooks
  • training_hooks
  • evaluation_hooks
  • prediction_hooks

The concrete hook interface is documented here: link. It consists of the following methods; a minimal custom hook is sketched after the list.

  • begin()
  • after_create_session(session, coord)
  • before_run(run_context)
  • after_run(run_context, run_values)
  • end(session)
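
As an illustration, here is a minimal sketch of a custom hook (the class name StepLoggerHook is hypothetical) that logs the global step around every session.run call:

import tensorflow as tf

class StepLoggerHook(tf.train.SessionRunHook):
    """A hypothetical hook: log the global step after every run call."""

    def begin(self):
        # Called once while the graph is still being built.
        self._global_step = tf.train.get_or_create_global_step()

    def before_run(self, run_context):
        # Ask the session to additionally fetch the global step.
        return tf.train.SessionRunArgs(self._global_step)

    def after_run(self, run_context, run_values):
        print("global_step =", run_values.results)

Such a hook is passed through training_hooks (or one of its siblings) in the returned EstimatorSpec, or through the hooks argument of train/evaluate/predict.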

Building an Estimator

Let's start with the simplest possible example: linear regression, y = Wx + b.

import tensorflow as tf

"""
Estimator interface
tf.estimator.Estimator(
    model_fn,
    model_dir=None,
    config=None,
    params=None,
    warm_start_from=None,
)
"""


class MyEstimator(tf.estimator.Estimator):
    """MyEstimator"""

    def __init__(self, model_dir, config=None, params=None):
        super(MyEstimator, self).__init__(
            self.model_fn,
            model_dir=model_dir,
            config=config,
            params=params,
        )

    def model_fn(self, features, labels, mode, config):
        # For the meaning of each argument, see
        # https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#args
        optimizer = tf.train.AdamOptimizer()
        x = features["x"]
        w = tf.Variable(0.1, name="w")
        b = tf.Variable(0.1, name="b")
        prediction = w * x + b
        print("Mode = ", mode)
        if mode == tf.estimator.ModeKeys.PREDICT:
            return tf.estimator.EstimatorSpec(mode, predictions=prediction)

        loss = tf.losses.mean_squared_error(labels, prediction)
        train_op = optimizer.minimize(
            loss, global_step=tf.train.get_or_create_global_step()
        )
        if mode == tf.estimator.ModeKeys.EVAL:
            metrics = {
                "mse": tf.metrics.mean_squared_error(labels, prediction)
            }
            return tf.estimator.EstimatorSpec(
                mode,
                predictions=prediction,
                eval_metric_ops=metrics,
                loss=loss,
            )

        if mode == tf.estimator.ModeKeys.TRAIN:
            return tf.estimator.EstimatorSpec(
                mode, predictions=prediction, loss=loss, train_op=train_op,
            )

        raise ValueError("Not a valid mode: {}".format(mode))

Suppose the model above is saved as model.py. In every mode, model_fn returns a tf.estimator.EstimatorSpec:

tf.estimator.EstimatorSpec(
    mode, predictions=None, loss=None, train_op=None, eval_metric_ops=None,
    export_outputs=None, training_chief_hooks=None, training_hooks=None,
    scaffold=None, evaluation_hooks=None, prediction_hooks=None
)

Using the Estimator

import logging
import os
import random
import subprocess

import tensorflow as tf

from model import MyEstimator


logging.getLogger().setLevel(logging.INFO)

model_dir = "/tmp/temp_model_dir/"
subprocess.check_call("rm -rf %s" % model_dir, shell=True)

estimator = MyEstimator(model_dir)

batch_size = 1

def train_input_fn():
    def generator():
        for _ in range(10):
            datum = random.random()
            yield "\t".join(map(str, (datum, datum * 0.8 + 1)))

    def parse(line):
        fields = tf.decode_csv(line, [[0.0], [0.0]], field_delim="\t")
        return {"x": fields[0]}, fields[1]

    dataset = tf.data.Dataset.from_generator(
        generator, tf.string, tf.TensorShape([])
    )
    dataset = dataset.map(parse)
    return dataset.batch(batch_size)


def serving_input_fn():
    feature_tensors = {
        "x": tf.placeholder(tf.float32, shape=(None, 1), name="input_x")
    }
    receiver_tensor = tf.placeholder(
        tf.float32, shape=(None, 1), name="output_tensor"
    )
    return tf.estimator.export.ServingInputReceiver(
        feature_tensors, receiver_tensor
    )


def predict_input_fn():
    def generator():
        for _ in range(10):
            datum = random.random()
            yield "\t".join(map(str, (datum,)))

    def parse(line):
        fields = tf.decode_csv(line, [[0.0]], field_delim="\t")
        return {"x": fields[0]}

    dataset = tf.data.Dataset.from_generator(
        generator, tf.string, tf.TensorShape([])
    )
    dataset = dataset.map(parse)
    return dataset.batch(batch_size)


estimator.train(train_input_fn)
estimator.evaluate(train_input_fn)
base = os.path.join(model_dir, "test")
result_dir = estimator.export_savedmodel(base, serving_input_fn)
print("Result dir: ", result_dir)

for data in estimator.predict(predict_input_fn):
    print(data)

Save the file above as main.py; running python main.py walks through the whole flow: training, evaluation, and prediction.
The example shows that, as long as we stay within Estimator's limited interface, we no longer need to worry about:

  • session creation
  • manually creating a SavedModelBuilder when exporting a SavedModel

What the algorithm engineer needs to care about instead is:

  • how the data is produced: the related input_fn
    • serving_input_fn: Stackoverflow
    • features: the input placeholders consumed by model_fn
    • receiver_tensors: the placeholders the serving client feeds; the features are derived from them by parsing (see the sketch after this list)
  • how the model is built: model_fn in model.py
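
To make the features / receiver_tensors distinction concrete, here is a minimal sketch of a serving_input_fn, assuming the serving client sends serialized tf.Example protos (the function and tensor names are illustrative):

import tensorflow as tf

def example_serving_input_fn():
    # receiver_tensors: what the serving client actually feeds.
    serialized = tf.placeholder(tf.string, shape=[None], name="input_example")
    # features: what model_fn consumes, parsed out of the receiver tensor.
    features = tf.parse_example(
        serialized, {"x": tf.FixedLenFeature([1], tf.float32)}
    )
    return tf.estimator.export.ServingInputReceiver(
        features, {"examples": serialized}
    )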

If we crack open train/evaluate/predict, each of them still creates a Session internally:

  • train – tf.train.MonitoredTrainingSession
  • evaluate – MonitoredSession
    • It is unclear why the tensorflow.python.training.evaluation module is used for evaluate, given that most estimator code is being split out of the tensorflow package (into tensorflow_estimator)
  • predict – tf.train.MonitoredSession

Distributed Estimator

Even when a single machine can hold every model parameter in memory, a huge sample volume still makes single-machine training fall short. With massive data, data-parallel training based on Parameter Server is the more common choice. Recent TensorFlow versions promote distribute.strategy; this article does not cover it, and instead sticks to traditional distributed training driven by cluster membership information.

Roles

Without Estimator, there are only two roles: ps and worker. Estimator adds three more:

  • master: deprecated; officially no longer supported
    • the master node does two jobs: it acts as worker 0 and as the evaluator
    • a single master node therefore carries too many responsibilities
  • chief: similar to worker 0 in the traditional setup
  • evaluator: a dedicated model-evaluation node
    • This role watches the checkpoint directory. Whenever a new checkpoint is produced, the evaluator restores the parameters from it, pulls data from eval_input_fn for scoring, and computes the values in eval_metric_ops. Based on the results, the user decides whether to export the model.

So in a distributed setup there are four roles in the cluster: ps, worker, chief, and evaluator.

  • ps and worker behave the same as in traditional distributed training
  • chief plays the role of worker 0, while worker task_index still starts from 0; worker 0 itself no longer has any special duties
  • once started, the evaluator watches model_dir for newly produced checkpoints

Driving distributed training

import tensorflow as tf
tf.estimator.train_and_evaluate(estimator_instance, train_spec, eval_spec)

This single call drives the estimator's training. The framework controls the whole training flow according to train_spec and eval_spec. Every role invokes this same interface, which internally dispatches to the concrete logic: Github

[Figure: train_and_evaluate execution flow]

tf.estimator.TrainSpec

tf.estimator.TrainSpec(
    input_fn, max_steps=None, hooks=None, saving_listeners=None
)

TrainSpec fields:

  • input_fn: same as before; produces the data model_fn consumes
  • max_steps: bounds the total number of training steps, letting the job finish early
  • hooks: instances derived from tf.estimator.SessionRunHook
  • saving_listeners: callbacks that fire around checkpoint saves (a sketch follows this list)
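
As a sketch of saving_listeners, assuming we simply want to react whenever a checkpoint is written (the class name LoggingSaverListener is hypothetical, and train_input_fn is the input function from earlier in this post; the saving_listeners argument follows the TrainSpec signature above):

import tensorflow as tf

class LoggingSaverListener(tf.train.CheckpointSaverListener):
    """A hypothetical listener: log every checkpoint save."""

    def after_save(self, session, global_step_value):
        print("Checkpoint saved at step", global_step_value)

train_spec = tf.estimator.TrainSpec(
    input_fn=train_input_fn,
    max_steps=500,
    saving_listeners=[LoggingSaverListener()],
)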

tf.estimator.EvalSpec

tf.estimator.EvalSpec(
    input_fn, steps=100, name=None, hooks=None, exporters=None,
    start_delay_secs=120, throttle_secs=600
)

EvalSpec fields:

  • input_fn: same as before
  • steps: caps the number of batches evaluated per run
  • hooks: instances derived from tf.estimator.SessionRunHook
  • exporters: estimator ships a few export-policy controls
    • for example BestExporter, derived from tf.estimator.Exporter (see the sketch below)
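
As a sketch, the following exports only when the evaluation loss improves (BestExporter's default comparison); serving_input_fn is the function from earlier in this post, and eval_input_fn stands in for your evaluation input function:

import tensorflow as tf

best_exporter = tf.estimator.BestExporter(
    name="best",
    serving_input_receiver_fn=serving_input_fn,
    exports_to_keep=3,
)
eval_spec = tf.estimator.EvalSpec(
    input_fn=eval_input_fn,
    steps=50,
    exporters=[best_exporter],
    start_delay_secs=0,
    throttle_secs=0,
)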

A complete example

import argparse
import json
import logging
import os
import random
import sys
import subprocess

import tensorflow as tf

from model import MyEstimator


logging.getLogger().setLevel(logging.INFO)

model_dir = "/tmp/temp_model_dir/"
subprocess.check_call("rm -rf %s" % model_dir, shell=True)


batch_size = 1
train_number = 1000
test_number = 100

def input_fn(data_size):
    def actual_input_fn():
        def generator():
            for _ in range(data_size):
                datum = random.random()
                yield "\t".join(map(str, (datum, datum * 0.8 + 1)))

        def parse(line):
            fields = tf.decode_csv(line, [[0.0], [0.0]], field_delim="\t")
            return {"x": fields[0]}, fields[1]

        dataset = tf.data.Dataset.from_generator(
            generator, tf.string, tf.TensorShape([])
        )
        dataset = dataset.map(parse)
        return dataset.batch(batch_size)
    return actual_input_fn


def serving_input_fn():
    feature_tensors = {
        "x": tf.placeholder(tf.float32, shape=(None, 1), name="input_x")
    }
    receiver_tensor = tf.placeholder(
        tf.float32, shape=(None, 1), name="output_tensor"
    )
    return tf.estimator.export.ServingInputReceiver(
        feature_tensors, receiver_tensor
    )

train_spec = tf.estimator.TrainSpec(
    input_fn(train_number), max_steps=500, hooks=None
)
eval_spec = tf.estimator.EvalSpec(
    input_fn(test_number), steps=50, name=None, hooks=None, exporters=None,
    start_delay_secs=0, throttle_secs=0
)

def get_cluster(args):
    """get_cluster"""
    cluster = {
        "cluster": {
            "ps": args.ps_hosts.split(";"),
            "worker": args.worker_hosts.split(";"),
            "chief": args.chief_hosts.split(";"),
        },
        "task": {
            "type": args.worker_type,
            "index": args.worker_index,
        }
    }
    os.environ["TF_CONFIG"] = json.dumps(cluster)

parser = argparse.ArgumentParser()
parser.add_argument("--ps-hosts")
parser.add_argument("--worker-hosts")
parser.add_argument("--chief-hosts")
parser.add_argument("--evaluator")
parser.add_argument("--worker-type", type=str)
parser.add_argument("--worker-index", type=int)

print("Argv: ", sys.argv)
args, _ = parser.parse_known_args()

get_cluster(args)

estimator = MyEstimator(model_dir)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

Save the file above as main.py.

run-dist.sh

#!/bin/bash

file=main.py

mkdir -p logs
FILE=logs/pid.file
if [ -f ${FILE} ]
then
    for i in `awk '{print $NF}' ${FILE}`
    do
        kill -9 $i
    done
fi

\rm -rf logs/*

function get_port() {
    local available_port=$(python -c \
        'from __future__ import print_function;\
        import socket; s = socket.socket(); s.bind(("", 0)); \
        print(s.getsockname()[1])')
    echo $available_port
}


function get_host() {
    size=$1
    hosts=""
    PORT=$(get_port)
    for i in `seq ${size}`
    do
        if [ -z "${hosts}" ]
        then
            hosts="localhost:"${PORT}
        else
            hosts=${hosts}";localhost:"${PORT}
        fi
        PORT=$(get_port)
    done

    echo ${hosts}
}

function start_tasks() {
    type=$1
    size=$2
    echo "Start ${type}, number: ${size}"
    ((size-=1))
    for i in `seq 0 ${size}`
    do
        index=$i
        python ${file} \
            --chief-hosts ${chief_hosts} \
            --evaluator-hosts ${evaluator_hosts} \
            --ps-hosts ${ps_hosts} \
            --worker-hosts ${worker_hosts} \
            --worker-type ${type} --worker-index ${index} &> logs/${type}.log.$i &
        echo "${type}: "${i}" pid= "$! >> logs/pid.file
    done

}

PS_SIZE=1
WORKER_SIZE=2
CHIEF_SIZE=1
EVALUATOR_SIZE=1
ps_hosts=$(get_host ${PS_SIZE})

worker_hosts=$(get_host ${WORKER_SIZE})
chief_hosts=$(get_host ${CHIEF_SIZE})
evaluator_hosts=$(get_host ${EVALUATOR_SIZE})

echo "ps = "${ps_hosts}
echo "worker = "${worker_hosts}
echo "chief = "${chief_hosts}
echo "evaluator = "${evaluator_hosts}
start_tasks "ps" ${PS_SIZE}

echo "Sleep 3s before start worker"
sleep 3s

start_tasks "worker" ${WORKER_SIZE}
start_tasks "evaluator" ${EVALUATOR_SIZE}

type="chief"
index=0

python ${file} \
    --chief-hosts ${chief_hosts} \
    --evaluator-hosts ${evaluator_hosts} \
    --ps-hosts ${ps_hosts} \
    --worker-hosts ${worker_hosts} \
    --worker-type ${type} --worker-index ${index} &> logs/chief.log.$i

Cluster setup

Given that one call can drive distributed training, how does the estimator identify its own role and run the corresponding logic?

Cluster information

Cluster membership is conveyed through the TF_CONFIG environment variable.

import json
import os

ps_hosts = ["a:1001", "b:1002"]
worker_hosts = ["a:1003", "b:1004"]
chief_hosts = ["a:1004", "b:1003"]

# For ps, worker, and chief
## worker task index starts from 0
## the evaluator must not appear in the cluster dict
cluster = {
    "cluster": {"ps": ps_hosts, "worker": worker_hosts, "chief": chief_hosts},
    "task": {
        "index": 0,
        "type": "worker",  # ps, chief, worker
    }
}

# For the evaluator, the config looks roughly like this:
## only a single evaluator is currently supported
cluster = {
    "cluster": {"ps": ps_hosts, "worker": worker_hosts, "chief": chief_hosts},
    "task": {
        "index": 0,
        "type": "evaluator",  # ps, chief, evaluator
    }
}

os.environ["TF_CONFIG"] = json.dumps(cluster)

Execution logic

The main execution logic lives here: link
When constructed, tf.estimator.RunConfig parses TF_CONFIG, identifies the process's role, and finally runs one of the following (the snippet after the list shows the parsed result):

  • run_ps
  • run_worker
  • run_chief
  • run_master
    • additionally starts an Evaluator
  • run_evaluator
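
A small sketch of the role discovery: constructing a RunConfig is enough to see the parsed result of TF_CONFIG:

import json
import os

import tensorflow as tf

os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {"ps": ["a:1001"], "worker": ["a:1003"], "chief": ["a:1004"]},
    "task": {"type": "worker", "index": 0},
})

# RunConfig parses TF_CONFIG at construction time.
config = tf.estimator.RunConfig()
print(config.task_type, config.task_id)  # -> worker 0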

Issues

TFOperator

The community TFOperator wires its cluster with the deprecated master + ps + worker layout, which overloads the master. Although the master runs evaluation in a child thread, the model is still loaded on a single machine, so it is memory-bound. Since Estimator decides its role purely from TF_CONFIG, we only need to rewrite this variable before starting the Estimator; a sketch follows.
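
A minimal sketch of such a rewrite, assuming TF_CONFIG has already been set with a master role by the operator:

import json
import os

tf_config = json.loads(os.environ["TF_CONFIG"])
cluster = tf_config["cluster"]
if "master" in cluster:
    # Rename the deprecated master role to chief before the
    # Estimator's RunConfig parses TF_CONFIG.
    cluster["chief"] = cluster.pop("master")
    if tf_config["task"]["type"] == "master":
        tf_config["task"]["type"] = "chief"
    os.environ["TF_CONFIG"] = json.dumps(tf_config)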

Synchronization between roles

The evaluator is launched independently; all it does is watch model_dir for new checkpoints and evaluate them. If the evaluator exits too early, the model may not be fully evaluated, so some synchronization is needed at exit time. For example, after the chief produces a checkpoint, it should confirm that the checkpoint has actually been evaluated; one possible scheme is sketched below.
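
One possible hand-rolled scheme, purely as a sketch (the marker-file convention here is hypothetical): the evaluator writes a marker file after scoring each checkpoint, and the chief blocks until the marker for its final checkpoint appears:

import os
import time

def wait_until_evaluated(model_dir, checkpoint_path, timeout_secs=600):
    """Chief side: block until the evaluator has marked this checkpoint."""
    marker = os.path.join(
        model_dir, "evaluated-" + os.path.basename(checkpoint_path)
    )
    deadline = time.time() + timeout_secs
    while not os.path.exists(marker):
        if time.time() > deadline:
            raise RuntimeError("never evaluated: %s" % checkpoint_path)
        time.sleep(5)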

The evaluator never exits

At present the evaluator's only exit condition is global_step > max_steps. If max_steps is set badly and no synchronization control is added, the evaluator will never exit on its own. Conversely, if the evaluator exits eagerly, newly produced checkpoints go unevaluated.

Distributed prediction

Estimator supports distributed training and evaluation, but the prediction logic is not distributed. See this answer: Stackoverflow.
The core idea is to keep reusing part of the estimator's logic, but to create a MonitoredTrainingSession when the session is built, so the model is restored automatically from the checkpoint directory.

  • start the server manually: without it, the process appears to hang
  • override the estimator's predict method
import argparse
import json
import logging
import os
import random
import sys
import subprocess

import six
import tensorflow as tf
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.training import server_lib
from tensorflow.python.training import training
from tensorflow.python.framework import random_seed
from tensorflow.python.eager import context
from tensorflow_estimator.python.estimator import model_fn as model_fn_lib
from tensorflow_estimator.python.estimator import estimator
from tensorflow.python.framework import ops

from model import MyEstimator


logging.getLogger().setLevel(logging.INFO)

model_dir = "/tmp/temp_model_dir/"


batch_size = 1
train_number = 1000
test_number = 100

def input_fn(data_size):
    def actual_input_fn():
        def generator():
            for _ in range(data_size):
                datum = random.random()
                yield "\t".join(map(str, (datum, datum * 0.8 + 1)))

        def parse(line):
            fields = tf.decode_csv(line, [[0.0], [0.0]], field_delim="\t")
            return {"x": fields[0]}, fields[1]

        dataset = tf.data.Dataset.from_generator(
            generator, tf.string, tf.TensorShape([])
        )
        dataset = dataset.map(parse)
        return dataset.batch(batch_size).make_one_shot_iterator().get_next()
    return actual_input_fn


def get_cluster(args):
    """get_cluster"""
    cluster = {
        "cluster": {
            "ps": args.ps_hosts.split(";"),
            "worker": args.worker_hosts.split(";"),
            "chief": args.chief_hosts.split(";"),
        },
        "task": {
            "type": args.worker_type,
            "index": args.worker_index,
        }
    }
    os.environ["TF_CONFIG"] = json.dumps(cluster)


def run_std_server(config):
    if config.session_config is None:
        session_config = config_pb2.ConfigProto(log_device_placement=False)
    else:
        session_config = config_pb2.ConfigProto(
            log_device_placement=False,
            gpu_options=config.session_config.gpu_options,
        )

    # The server must be created outside the else-branch; otherwise the
    # function silently returns None when session_config is None.
    server = server_lib.Server(
        config.cluster_spec,
        job_name=config.task_type,
        task_index=config.task_id,
        config=session_config,
        start=False,
        protocol=config.protocol,
    )
    server.start()
    return server


def hook_predict(args, config):

    # Override estimator predict
    def predict(
        self,
        input_fn,
        predict_keys=None,
        hooks=None,
        checkpoint_dir=None,
        yield_single_examples=True,
    ):
        """Arguments are same with Estimator.predict"""
        with context.graph_mode():
            hooks = estimator._check_hooks_type(hooks)
            # Check that model has been trained.
            if not checkpoint_dir:
                raise ValueError("No checkpoint_dir")
            with ops.Graph().as_default() as g, g.device(self._device_fn):
                random_seed.set_random_seed(self._config.tf_random_seed)
                self._create_and_assert_global_step(g)
                features, input_hooks = self._get_features_from_input_fn(
                    input_fn, model_fn_lib.ModeKeys.PREDICT
                )
                estimator_spec = self._call_model_fn(
                    features,
                    None,
                    model_fn_lib.ModeKeys.PREDICT,
                    self.config,
                )

                predictions = self._extract_keys(
                    estimator_spec.predictions, predict_keys
                )
                all_hooks = list(input_hooks)
                all_hooks.extend(hooks)
                all_hooks.extend(
                    list(estimator_spec.prediction_hooks or [])
                )
                with training.MonitoredTrainingSession(
                    is_chief=args.worker_type=="chief",
                    master=config.master,
                    checkpoint_dir=checkpoint_dir,
                    config=config.session_config,
                ) as mon_sess:

                    while not mon_sess.should_stop():
                        preds_evaluated = mon_sess.run(predictions)
                        if not yield_single_examples:
                            yield preds_evaluated
                        elif not isinstance(predictions, dict):
                            for pred in preds_evaluated:
                                yield pred
                        else:
                            for i in range(
                                self._extract_batch_length(preds_evaluated)
                            ):
                                yield {
                                    key: value[i]
                                    for key, value in six.iteritems(
                                        preds_evaluated
                                    )
                                }
    estimator.Estimator.predict = predict


parser = argparse.ArgumentParser()
parser.add_argument("--ps-hosts")
parser.add_argument("--worker-hosts")
parser.add_argument("--chief-hosts")
parser.add_argument("--evaluator")
parser.add_argument("--worker-type", type=str)
parser.add_argument("--worker-index", type=int)

print("Argv: ", sys.argv)
args, _ = parser.parse_known_args()

get_cluster(args)

user_estimator = MyEstimator(model_dir)

server = run_std_server(user_estimator.config)

if args.worker_type == "ps":
    server.join()
else:
    hook_predict(args, user_estimator.config)
    kwargs = {
        "checkpoint_dir":  model_dir,
    }
    for data in user_estimator.predict(input_fn(10), **kwargs):
        print(data)

#!/bin/bash

killed_exit=$1
file=main_dist.py

mkdir -p logs
FILE=logs/pid.file
if [ -f ${FILE} ]
then
    for i in `awk '{print $NF}' ${FILE}`
    do
        kill -9 $i
    done
fi

[[ ! -z ${killed_exit} ]] && exit 0


\rm -rf logs/*

function get_port() {
    local available_port=$(python -c \
        'from __future__ import print_function;\
        import socket; s = socket.socket(); s.bind(("", 0)); \
        print(s.getsockname()[1])')
    echo $available_port
}

function get_host() {
    size=$1
    hosts=""
    PORT=$(get_port)
    for i in `seq ${size}`
    do
        if [ -z "${hosts}" ]
        then
            hosts="localhost:"${PORT}
        else
            hosts=${hosts}";localhost:"${PORT}
        fi
        PORT=$(get_port)
    done

    echo ${hosts}
}

function start_tasks() {
    type=$1
    size=$2
    echo "Start ${type}, number: ${size}"
    ((size-=1))
    for i in `seq 0 ${size}`
    do
        index=$i
        python ${file} \
            --chief-hosts ${chief_hosts} \
            --ps-hosts ${ps_hosts} \
            --worker-hosts ${worker_hosts} \
            --worker-type ${type} --worker-index ${index} &> logs/${type}.log.$i &
        echo "${type}: "${i}" pid= "$! >> logs/pid.file
    done

}

PS_SIZE=1
WORKER_SIZE=2
CHIEF_SIZE=1
EVALUATOR_SIZE=1
ps_hosts=$(get_host ${PS_SIZE})

worker_hosts=$(get_host ${WORKER_SIZE})
chief_hosts=$(get_host ${CHIEF_SIZE})

echo "ps = "${ps_hosts}
echo "worker = "${worker_hosts}
echo "chief = "${chief_hosts}
start_tasks "ps" ${PS_SIZE}

echo "Sleep 3s before start worker"
sleep 3s

start_tasks "worker" ${WORKER_SIZE}

type="chief"
index=0

python ${file} \
    --chief-hosts ${chief_hosts} \
    --ps-hosts ${ps_hosts} \
    --worker-hosts ${worker_hosts} \
    --worker-type ${type} --worker-index ${index} &> logs/chief.log.$i

Github code