# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ClusterPreserveQuantizeRegistry."""

import tensorflow as tf

from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry
from tensorflow_model_optimization.python.core.keras.compat import keras
from tensorflow_model_optimization.python.core.quantization.keras import quantize_config
from tensorflow_model_optimization.python.core.quantization.keras.collab_opts.cluster_preserve import cluster_preserve_quantize_registry
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry

QuantizeConfig = quantize_config.QuantizeConfig
layers = keras.layers


class ClusterPreserveQuantizeRegistryTest(tf.test.TestCase):
  """Tests layer support and config application of the CQAT registry."""

  def setUp(self):
    super(ClusterPreserveQuantizeRegistryTest, self).setUp()
    # Test CQAT by default. NOTE(review): the positional False presumably
    # disables sparsity preservation (i.e. plain CQAT rather than PCQAT);
    # confirm against ClusterPreserveQuantizeRegistry's signature.
    self.cluster_preserve_quantize_registry = (
        cluster_preserve_quantize_registry.ClusterPreserveQuantizeRegistry(
            False)
    )

    # Layers which are supported.
    # Initialize and build a Conv2D layer.
    self.layer_conv2d = layers.Conv2D(10, (2, 2))
    self.layer_conv2d.build((2, 2))
    # Initialize and build a Dense layer.
    self.layer_dense = layers.Dense(10)
    self.layer_dense.build((2, 2))
    # Initialize and build a ReLU layer (a layer without trainable weights).
    self.layer_relu = layers.ReLU()
    self.layer_relu.build((2, 2))

    # A layer which is not supported.
    # Initialize and build a custom layer.
    self.layer_custom = self.CustomLayer()
    self.layer_custom.build()

  class CustomLayer(layers.Layer):
    """A simple custom layer with training weights."""

    def build(self, input_shape=(2, 2)):
      # Adds one trainable weight whose shape equals `input_shape`.
      self.add_weight(shape=input_shape,
                      initializer='random_normal',
                      trainable=True)

  class CustomQuantizeConfig(QuantizeConfig):
    """A dummy concrete class for testing unregistered configs.

    Every abstract QuantizeConfig method is implemented as a no-op, so the
    class can be instantiated.
    """

    def get_weights_and_quantizers(self, layer):
      return []

    def get_activations_and_quantizers(self, layer):
      return []

    def set_quantize_weights(self, layer, quantize_weights):
      pass

    def set_quantize_activations(self, layer, quantize_activations):
      pass

    def get_output_quantizers(self, layer):
      return []

    def get_config(self):
      return {}

  def testSupportsKerasLayer(self):
    """Registered layers, including weightless ones, are supported."""
    # Test registered layers.
    self.assertTrue(
        self.cluster_preserve_quantize_registry.supports(self.layer_dense))
    self.assertTrue(
        self.cluster_preserve_quantize_registry.supports(self.layer_conv2d))
    # Test a layer without training weights.
    self.assertTrue(
        self.cluster_preserve_quantize_registry.supports(self.layer_relu))

  def testDoesNotSupportCustomLayer(self):
    """An unregistered custom layer is reported as unsupported."""
    self.assertFalse(
        self.cluster_preserve_quantize_registry.supports(self.layer_custom))

  def testApplyClusterPreserveWithQuantizeConfig(self):
    """Applying a registered config to a supported layer does not raise."""
    (self.cluster_preserve_quantize_registry
     .apply_cluster_preserve_quantize_config(
         self.layer_conv2d,
         default_8bit_quantize_registry.Default8BitConvQuantizeConfig(
             ['kernel'], ['activation'], False)))

  def testRaisesErrorUnsupportedQuantizeConfigWithLayer(self):
    """Unregistered configs and unregistered layers raise ValueError."""
    # Pass an *instance* of the unregistered config. Passing the class object
    # itself would make the registry inspect the metaclass instead, so the
    # error could be raised for the wrong reason.
    with self.assertRaises(
        ValueError, msg='Unregistered QuantizeConfigs should raise error.'):
      (self.cluster_preserve_quantize_registry.
       apply_cluster_preserve_quantize_config(
           self.layer_conv2d, self.CustomQuantizeConfig()))

    with self.assertRaises(ValueError,
                           msg='Unregistered layers should raise error.'):
      (self.cluster_preserve_quantize_registry.
       apply_cluster_preserve_quantize_config(
           self.layer_custom, self.CustomQuantizeConfig()))


class ClusterPreserveDefault8bitQuantizeRegistryTest(tf.test.TestCase):
  """Checks that CQAT layer support is consistent with its two parents."""

  def setUp(self):
    super(ClusterPreserveDefault8bitQuantizeRegistryTest, self).setUp()
    self.default_8bit_quantize_registry = (
        default_8bit_quantize_registry.Default8BitQuantizeRegistry())
    self.cluster_registry = clustering_registry.ClusteringRegistry()
    # Test CQAT by default.
    self.cluster_preserve_quantize_registry = (
        cluster_preserve_quantize_registry.ClusterPreserveQuantizeRegistry(
            False))

  def testSupportsClusterDefault8bitQuantizeKerasLayers(self):
    """Each CQAT layer with weights is supported by clustering and 8-bit QAT."""
    # A layer supported by ClusterPreserveQuantize must be supported
    # by both Cluster and Quantize.
    cqat_layers_config_map = (
        self.cluster_preserve_quantize_registry._LAYERS_CONFIG_MAP)
    for cqat_support_layer, layer_config in cqat_layers_config_map.items():
      # Only entries that declare both weight attributes and quantize-config
      # attributes are checked against the other two registries.
      if layer_config.weight_attrs and layer_config.quantize_config_attrs:
        self.assertIn(
            cqat_support_layer,
            self.cluster_registry._LAYERS_WEIGHTS_MAP,
            msg='Clustering doesn\'t support {}'.format(cqat_support_layer))
        self.assertIn(
            cqat_support_layer,
            self.default_8bit_quantize_registry._layer_quantize_map,
            msg='Default 8bit QAT doesn\'t support {}'.format(
                cqat_support_layer))


if __name__ == '__main__':
  tf.test.main()