Merge TensorFlow models

Posted by chunyang on December 27, 2020

Sometimes you want to transfer certain weights from multiple models into a single model, or simply merge several models into one. There are at least two ways to do that:


The first way is to use tf.train.init_from_checkpoint.

    tf.train.init_from_checkpoint(ckpt_dir_or_file, assignment_map)
  • ckpt_dir_or_file
    • Can be a checkpoint directory or a checkpoint prefix: /path/to/checkpoint_dir, /path/to/checkpoint_dir/model-1234
    • Can be a saved model: /path/to/saved_model/variables/variables (the second variables is the filename prefix of variables.index)
  • assignment_map
    • key: a scope or variable name in the checkpoint
    • value: a scope, a variable name, or a variable reference in the current graph

It is very flexible. Under the hood, init_from_checkpoint modifies the initializer of a variable; when we run tf.global_variables_initializer(), the related restore op is executed instead of the original initializer:

variable._initializer_op = init_op
variable._initial_value = restore_op

If you have user-defined variables, for example an AwesomeVariable that behaves like a TensorFlow Variable but uses a different back-end storage, you can implement a similar function by creating your own initializer op.

# Create an init op that restores the checkpoint value into your
# custom back-end storage (implementation-specific)
init_op = ...

# Replace the initializer op
variable._initializer_op = init_op

init_from_checkpoint can be called multiple times with different ckpt_dir_or_file arguments. As a result, a single model's variables can be initialized from different source checkpoints or saved models.

import os
import tensorflow as tf

os.makedirs("./models/a", exist_ok=True)
os.makedirs("./models/b", exist_ok=True)

with tf.Session(graph=tf.Graph()) as session:
    tf.Variable(3, name="a")
    session.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.save(session, "./models/a/model-a")

with tf.Session(graph=tf.Graph()) as session:
    tf.Variable(4, name="b")
    session.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.save(session, "./models/b/model-b")

with tf.Session(graph=tf.Graph()) as session:
    a = tf.Variable(1, name="a")
    b = tf.Variable(1, name="b")
    tf.train.init_from_checkpoint("./models/a/model-a", {"a": a})
    tf.train.init_from_checkpoint("./models/b/model-b", {"b": b})
    session.run(tf.global_variables_initializer())
    print(session.run([a, b]))  # [3, 4]


The second way relies on how loading works: when TensorFlow loads a model, the only requirement is that every variable in the model has a valid value in the checkpoint. So we can first merge the checkpoints of the different models, and then load just once.

import tensorflow as tf
from tensorflow.python.ops import gen_io_ops

with tf.Session(graph=tf.Graph()) as session:
    src = tf.constant(["./models/a/model-a", "./models/b/model-b"])
    target = tf.constant("./models/merged_model")
    op = gen_io_ops.merge_v2_checkpoints(src, target, delete_old_dirs=False)
    session.run(op)


with tf.Session(graph=tf.Graph()) as session:
    a = tf.Variable(1, name="a")
    b = tf.Variable(1, name="b")
    saver = tf.train.Saver()
    saver.restore(session, "./models/merged_model")
    print(session.run([a, b]))  # [3, 4]

Be careful: merge_v2_checkpoints deletes the source .index and .data files regardless of whether delete_old_dirs is set to True or False; the flag only controls whether the now-empty source directories are removed as well.
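Since the source files are consumed, it is safer to copy them first if you still need the originals. A minimal sketch using only the standard library (backup_checkpoint is a hypothetical helper, not a TensorFlow API):

```python
import glob
import os
import shutil

def backup_checkpoint(prefix, backup_dir):
    """Copy every file belonging to a checkpoint prefix
    (e.g. prefix.index, prefix.data-00000-of-00001) into backup_dir."""
    os.makedirs(backup_dir, exist_ok=True)
    for path in glob.glob(glob.escape(prefix) + ".*"):
        shutil.copy(path, backup_dir)

# Example: keep copies of model-a's files before merging consumes them.
backup_checkpoint("./models/a/model-a", "./models/backup_a")
```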