Add Metadata to a TensorFlow Frozen Graph (.pb)
Solution 1:
First of all, yes, you should use the new SavedModel format, as it is what the TF team will support going forward, and it works with Keras as well. You can add an additional endpoint to the model that returns a constant tensor (as you mention) containing a string of your XML data.
This approach is hermetic: the underlying SavedModel serialization does not matter, because your metadata is stored in the computation graph itself.
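For illustration, here is a minimal sketch of that idea outside of Keras, assuming TF 2.x (the MetadataModule name and the example XML string are just placeholders): a constant string tensor holding the metadata is attached to a tf.Module, exported, and read back after loading.

import tensorflow as tf

class MetadataModule(tf.Module):
    def __init__(self, metadata_xml):
        super().__init__()
        # The metadata lives in the graph as a constant string tensor.
        self.metadata = tf.constant(metadata_xml)

    @tf.function(input_signature=[])
    def get_metadata(self):
        return self.metadata

module = MetadataModule('<metadata><version>1</version></metadata>')
tf.saved_model.save(module, 'saved_model_with_metadata')

reloaded = tf.saved_model.load('saved_model_with_metadata')
print(reloaded.get_metadata())  # tf.Tensor(b'<metadata>...</metadata>', shape=(), dtype=string)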
See the answer to this question: Saving a TF2 keras model with custom signature defs. That answer doesn't get you 100% of the way there for Keras, because it doesn't interoperate nicely with the tf.keras.models.load_model function, since the model is wrapped inside a tf.Module. Luckily, the same idea works with tf.keras.Model in TF2 if you add a tf.function decorator:
import tensorflow as tf

class MyModel(tf.keras.Model):
    def __init__(self, metadata, **kwargs):
        super(MyModel, self).__init__(**kwargs)
        self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
        # The metadata string is stored in the graph as a constant tensor.
        self.metadata = tf.constant(metadata)

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)

    @tf.function(input_signature=[])
    def get_metadata(self):
        return self.metadata

model = MyModel('metadata_test')
# This call is needed so Keras knows its input shape. You could also define it manually.
input_arr = tf.random.uniform((5, 5, 1))
outputs = model(input_arr)
Then you can save and load your model as follows:
tf.keras.models.save_model(model, 'test_model_keras')
model_loaded = tf.keras.models.load_model('test_model_keras')
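As a quick sanity check (purely illustrative), the reloaded model still runs inference just like the original:

outputs_loaded = model_loaded(input_arr)  # same output shape as before: (5, 5, 5)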
Finally, call model_loaded.get_metadata() to retrieve your constant metadata tensor.
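Note that get_metadata() returns a scalar string tensor, so if you want the original Python string back (for example, to parse the XML), decode the underlying bytes, roughly like this:

metadata_str = model_loaded.get_metadata().numpy().decode('utf-8')
print(metadata_str)  # 'metadata_test'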