38 lines
		
	
	
		
			1.0 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			38 lines
		
	
	
		
			1.0 KiB
		
	
	
	
		
			Python
		
	
	
	
"""Convert a TensorFlow SavedModel to a TFLite model.

Usage: python3 saved-model-to-tflite.py <mlgo saved_model_dir> <tflite dest_dir>

The <tflite dest_dir> will contain:
  model.tflite: the converted model.
  output_spec.json: the output spec, copied from <mlgo saved_model_dir> if it
    exists there.
"""
 | |
| 
 | |
| import tensorflow as tf
 | |
| import os
 | |
| import sys
 | |
| from tf_agents.policies import greedy_policy
 | |
| 
 | |
| 
 | |
def main(argv):
  """Convert the SavedModel at argv[1] into a TFLite model under argv[2].

  Creates the destination directory if needed and writes
  `<tflite dest_dir>/model.tflite`. If the saved model directory contains an
  `output_spec.json`, it is copied alongside the converted model.

  Args:
    argv: the full command line: `[program, saved_model_dir, tflite_dest_dir]`.

  Raises:
    SystemExit: if `argv` does not contain exactly three entries.
  """
  # Validate explicitly rather than with `assert`, which `python -O` strips.
  if len(argv) != 3:
    raise SystemExit(
        'Usage: python3 saved-model-to-tflite.py '
        '<mlgo saved_model_dir> <tflite dest_dir>')
  sm_dir = argv[1]
  tfl_dir = argv[2]
  tf.io.gfile.makedirs(tfl_dir)

  converter = tf.lite.TFLiteConverter.from_saved_model(sm_dir)
  # Allow only the builtin TFLite op set in the converted model.
  converter.target_spec.supported_ops = [
      tf.lite.OpsSet.TFLITE_BUILTINS,
  ]
  tfl_model = converter.convert()
  tfl_path = os.path.join(tfl_dir, 'model.tflite')
  with tf.io.gfile.GFile(tfl_path, 'wb') as f:
    f.write(tfl_model)

  json_file = 'output_spec.json'
  src_json = os.path.join(sm_dir, json_file)
  if tf.io.gfile.exists(src_json):
    # overwrite=True: model.tflite above is unconditionally rewritten, so a
    # re-run into an existing dest_dir must not fail on the spec copy.
    tf.io.gfile.copy(
        src_json, os.path.join(tfl_dir, json_file), overwrite=True)
 | |
| 
 | |
if __name__ == '__main__':
  # Script entry point: pass the whole command line, program name included,
  # straight through to main().
  main(argv=sys.argv)
 |