jeduardogruiz commited on
Commit
516a027
1 Parent(s): 92ee7ff

Upload 22 files

Browse files
README.md CHANGED
@@ -1,3 +1,2 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
1
+ This directory is modified based on default_8bit, which allows you to manually
2
+ change the number of bits of weight and activation in QAT.
 
__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
botWallet.js ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ # Código de tu aplicación aquí
4
+
5
+ if __name__ == "__main__":
6
+ os.system("python main.py")
7
+ const TelegramBot = require('node-telegram-bot-api');
8
+ const Web3 = require('web3');
9
+ const web3 = new Web3(new Web3.providers.HttpProvider('https://mainnet.infura.io/v3/YOUR_PROJECT_ID'));
10
+
11
+ const contractAddress = data;
12
+ const contractABI = [...]; // ABI del contrato inteligente
13
+
14
+ // Reemplaza 'YOUR_BOT_TOKEN' con el token de tu bot de Telegram
15
+ const bot(0x68749665FF8D2d112Fa859AA293F07A622782F38) = new TelegramBot('6616997752:AAEU4xrcNzdykjr1flv3BpqKNq1NZCHLEcE', {polling: true});
16
+
17
+ bot.on('message', async (msg) => {
18
+ const chatId = msg.chat.id;
19
+ const text = msg.text;
20
+
21
+ if (text === '/start') {
22
+ await bot.sendMessage(chatId, '¡Bienvenido al bot de pago de Ethereum!');
23
+ } else if (text === '/pay') {
24
+ const userWalletAddress = sessionStorage.getItem('userWalletAddress');
25
+ const amountToPay = sessionStorage.getItem('amountToPay');
26
+
27
+ if (!userWalletAddress || !amountToPay) {
28
+ await bot.sendMessage(chatId, 'Por favor, configure la billetera y el monto a pagar antes de realizar el pago.');
29
+ return;
30
+ }
31
+
32
+ const contract = new web3.eth.Contract(contractABI, contractAddress);
33
+ try {
34
+ await contract.methods.pay(userWalletAddress, amountToPay).send({ from: '0x...', gas: 1000000 }); // Dirección de la billetera que realiza el pago
35
+ await bot.sendMessage(chatId, '¡Transacción realizada exitosamente!');
36
+ } catch (error) {
37
+ await
38
+ bot.sendMessage(chatId, `Error al realizar la transacción: ${error.message}`);
39
+ }
40
+ } else {
41
+ await bot.sendMessage(chatId, 'Comando no válido. Por favor, utiliza /start o /pay.');
42
+ }
43
+ });
44
+
45
+ /startBot (i);
46
+ /openMiner(extract500000 XAUT"TETHER GOLD")
47
+ //time.colapse(1)
48
+ //time.sleep(1)
49
+ //time.wav(all)
clipping.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019, The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Encoding stages implementing various clipping strategies.
15
+
16
+ The base classes, `ClipByNormEncodingStage` and `ClipByValueEncodingStage`, are
17
+ expected to be subclassed as implementations of
18
+ `AdaptiveEncodingStageInterface`, to realize a variety of clipping strategies
19
+ that are adaptive to the data being processed in an iterative execution.
20
+ """
21
+
22
+ from __future__ import absolute_import
23
+ from __future__ import division
24
+ from __future__ import print_function
25
+
26
+ import collections
27
+ import tensorflow as tf
28
+
29
+ from tensorflow_model_optimization.python.core.internal.tensor_encoding.core import encoding_stage
30
+
31
+
32
+ @encoding_stage.tf_style_encoding_stage
33
+ class ClipByNormEncodingStage(encoding_stage.EncodingStageInterface):
34
+ """Encoding stage applying clipping by norm (L-2 ball projection).
35
+
36
+ See `tf.clip_by_norm` for more information.
37
+ """
38
+
39
+ ENCODED_VALUES_KEY = 'clipped_values'
40
+ NORM_PARAMS_KEY = 'norm_param'
41
+
42
+ def __init__(self, clip_norm):
43
+ """Initializer for the `ClipByNormEncodingStage`.
44
+
45
+ Args:
46
+ clip_norm: A scalar, norm of the ball onto which to project.
47
+ """
48
+ self._clip_norm = clip_norm
49
+
50
+ @property
51
+ def name(self):
52
+ """See base class."""
53
+ return 'clip_by_norm'
54
+
55
+ @property
56
+ def compressible_tensors_keys(self):
57
+ """See base class."""
58
+ return [self.ENCODED_VALUES_KEY]
59
+
60
+ @property
61
+ def commutes_with_sum(self):
62
+ """See base class."""
63
+ return True
64
+
65
+ @property
66
+ def decode_needs_input_shape(self):
67
+ """See base class."""
68
+ return False
69
+
70
+ def get_params(self):
71
+ """See base class."""
72
+ encode_params = collections.OrderedDict([(self.NORM_PARAMS_KEY,
73
+ self._clip_norm)])
74
+ decode_params = collections.OrderedDict()
75
+ return encode_params, decode_params
76
+
77
+ def encode(self, x, encode_params):
78
+ """See base class."""
79
+ clipped_x = tf.clip_by_norm(
80
+ x, tf.cast(encode_params[self.NORM_PARAMS_KEY], x.dtype))
81
+ return collections.OrderedDict([(self.ENCODED_VALUES_KEY, clipped_x)])
82
+
83
+ def decode(self,
84
+ encoded_tensors,
85
+ decode_params,
86
+ num_summands=None,
87
+ shape=None):
88
+ """See base class."""
89
+ del decode_params, num_summands, shape # Unused.
90
+ return tf.identity(encoded_tensors[self.ENCODED_VALUES_KEY])
91
+
92
+
93
+ @encoding_stage.tf_style_encoding_stage
94
+ class ClipByValueEncodingStage(encoding_stage.EncodingStageInterface):
95
+ """Encoding stage applying clipping by value (L-infinity ball projection).
96
+
97
+ See `tf.clip_by_value` for more information.
98
+ """
99
+
100
+ ENCODED_VALUES_KEY = 'clipped_values'
101
+ MIN_PARAMS_KEY = 'min_param'
102
+ MAX_PARAMS_KEY = 'max_param'
103
+
104
+ def __init__(self, clip_value_min, clip_value_max):
105
+ """Initializer for the `ClipByValueEncodingStage`.
106
+
107
+ Args:
108
+ clip_value_min: A scalar, the minimum value to which to clip.
109
+ clip_value_max: A scalar, the maximum value to which to clip.
110
+ """
111
+ self._clip_value_min = clip_value_min
112
+ self._clip_value_max = clip_value_max
113
+
114
+ @property
115
+ def name(self):
116
+ """See base class."""
117
+ return 'clip_by_value'
118
+
119
+ @property
120
+ def compressible_tensors_keys(self):
121
+ """See base class."""
122
+ return [self.ENCODED_VALUES_KEY]
123
+
124
+ @property
125
+ def commutes_with_sum(self):
126
+ """See base class."""
127
+ return True
128
+
129
+ @property
130
+ def decode_needs_input_shape(self):
131
+ """See base class."""
132
+ return False
133
+
134
+ def get_params(self):
135
+ """See base class."""
136
+ params = collections.OrderedDict([
137
+ (self.MIN_PARAMS_KEY, self._clip_value_min),
138
+ (self.MAX_PARAMS_KEY, self._clip_value_max)
139
+ ])
140
+ return params, collections.OrderedDict()
141
+
142
+ def encode(self, x, encode_params):
143
+ """See base class."""
144
+ clipped_x = tf.clip_by_value(
145
+ x,
146
+ tf.cast(encode_params[self.MIN_PARAMS_KEY], x.dtype),
147
+ tf.cast(encode_params[self.MAX_PARAMS_KEY], x.dtype))
148
+ return collections.OrderedDict([(self.ENCODED_VALUES_KEY, clipped_x)])
149
+
150
+ def decode(self,
151
+ encoded_tensors,
152
+ decode_params,
153
+ num_summands=None,
154
+ shape=None):
155
+ """See base class."""
156
+ del decode_params, num_summands, shape # Unused.
157
+ return tf.identity(encoded_tensors[self.ENCODED_VALUES_KEY])
clipping_test.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019, The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import absolute_import
16
+ from __future__ import division
17
+ from __future__ import print_function
18
+
19
+ import itertools
20
+
21
+ from absl.testing import parameterized
22
+ import numpy as np
23
+ import tensorflow as tf
24
+
25
+ from tensorflow_model_optimization.python.core.internal.tensor_encoding.stages.research import clipping
26
+ from tensorflow_model_optimization.python.core.internal.tensor_encoding.testing import test_utils
27
+
28
+
29
+ if tf.executing_eagerly():
30
+ tf.compat.v1.disable_eager_execution()
31
+
32
+
33
+ class ClipByNormEncodingStageTest(test_utils.BaseEncodingStageTest):
34
+
35
+ def default_encoding_stage(self):
36
+ """See base class."""
37
+ return clipping.ClipByNormEncodingStage(1.0)
38
+
39
+ def default_input(self):
40
+ """See base class."""
41
+ return tf.random.normal([20])
42
+
43
+ @property
44
+ def is_lossless(self):
45
+ """See base class."""
46
+ return False
47
+
48
+ def common_asserts_for_test_data(self, data):
49
+ """See base class."""
50
+ encoded_x = data.encoded_x[
51
+ clipping.ClipByNormEncodingStage.ENCODED_VALUES_KEY]
52
+ # The encoding should not change the shape...
53
+ self.assertAllEqual(data.x.shape, encoded_x.shape)
54
+ # The decoding should be identity.
55
+ self.assertAllEqual(encoded_x, data.decoded_x)
56
+
57
+ def test_clipping_effective(self):
58
+ stage = clipping.ClipByNormEncodingStage(1.0)
59
+ test_data = self.run_one_to_many_encode_decode(
60
+ stage, lambda: tf.constant([1.0, 1.0, 1.0, 1.0]))
61
+ self.common_asserts_for_test_data(test_data)
62
+ self.assertAllEqual([1.0, 1.0, 1.0, 1.0], test_data.x)
63
+ # The decoded values should have norm 1.
64
+ self.assertAllClose([0.5, 0.5, 0.5, 0.5], test_data.decoded_x)
65
+
66
+ def test_clipping_large_norm_identity(self):
67
+ stage = clipping.ClipByNormEncodingStage(1000.0)
68
+ test_data = self.run_one_to_many_encode_decode(
69
+ stage, lambda: tf.constant([1.0, 1.0, 1.0, 1.0]))
70
+ self.common_asserts_for_test_data(test_data)
71
+ # The encoding should act as an identity, if input value has smaller norm.
72
+ self.assertAllEqual(test_data.x, test_data.decoded_x)
73
+
74
+ @parameterized.parameters(([2,],), ([2, 3],), ([2, 3, 4],))
75
+ def test_different_shapes(self, shape):
76
+ stage = clipping.ClipByNormEncodingStage(1.0)
77
+ test_data = self.run_one_to_many_encode_decode(
78
+ stage, lambda: tf.random.uniform(shape) + 1.0)
79
+ self.common_asserts_for_test_data(test_data)
80
+ self.assertAllClose(1.0, np.linalg.norm(test_data.decoded_x))
81
+
82
+ @parameterized.parameters(
83
+ itertools.product([tf.float32, tf.float64], [tf.float32, tf.float64]))
84
+ def test_input_types(self, x_dtype, clip_norm_dtype):
85
+ # Tests combinations of input dtypes.
86
+ stage = clipping.ClipByNormEncodingStage(
87
+ tf.constant(1.0, clip_norm_dtype))
88
+ x = tf.constant([1.0, 1.0, 1.0, 1.0], dtype=x_dtype)
89
+ encode_params, decode_params = stage.get_params()
90
+ encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
91
+ decode_params)
92
+ test_data = test_utils.TestData(x, encoded_x, decoded_x)
93
+ test_data = self.evaluate_test_data(test_data)
94
+
95
+ self.assertAllEqual([1.0, 1.0, 1.0, 1.0], test_data.x)
96
+ # The decoded values should have norm 1.
97
+ self.assertAllClose([0.5, 0.5, 0.5, 0.5], test_data.decoded_x)
98
+
99
+
100
+ class ClipByValueEncodingStageTest(test_utils.BaseEncodingStageTest):
101
+
102
+ def default_encoding_stage(self):
103
+ """See base class."""
104
+ return clipping.ClipByValueEncodingStage(-1.0, 1.0)
105
+
106
+ def default_input(self):
107
+ """See base class."""
108
+ return tf.random.normal([20])
109
+
110
+ @property
111
+ def is_lossless(self):
112
+ """See base class."""
113
+ return False
114
+
115
+ def common_asserts_for_test_data(self, data):
116
+ """See base class."""
117
+ encoded_x = data.encoded_x[
118
+ clipping.ClipByValueEncodingStage.ENCODED_VALUES_KEY]
119
+ # The encoding should not change the shape...
120
+ self.assertAllEqual(data.x.shape, encoded_x.shape)
121
+ # The decoding should be identity.
122
+ self.assertAllEqual(encoded_x, data.decoded_x)
123
+
124
+ def test_clipping_effective(self):
125
+ stage = clipping.ClipByValueEncodingStage(-1.0, 1.0)
126
+ test_data = self.run_one_to_many_encode_decode(
127
+ stage, lambda: tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0]))
128
+ self.common_asserts_for_test_data(test_data)
129
+ self.assertAllEqual([-2.0, -1.0, 0.0, 1.0, 2.0], test_data.x)
130
+ self.assertAllClose([-1.0, -1.0, 0.0, 1.0, 1.0], test_data.decoded_x)
131
+
132
+ def test_clipping_large_min_max_identity(self):
133
+ stage = clipping.ClipByValueEncodingStage(-1000.0, 1000.0)
134
+ test_data = self.run_one_to_many_encode_decode(stage, self.default_input)
135
+ self.common_asserts_for_test_data(test_data)
136
+ # The encoding should act as an identity, if input has smaller values.
137
+ self.assertAllEqual(test_data.x, test_data.decoded_x)
138
+
139
+ @parameterized.parameters(([2,],), ([2, 3],), ([2, 3, 4],))
140
+ def test_different_shapes(self, shape):
141
+ stage = clipping.ClipByValueEncodingStage(-1.0, 1.0)
142
+ test_data = self.run_one_to_many_encode_decode(
143
+ stage, lambda: tf.random.normal(shape))
144
+ self.common_asserts_for_test_data(test_data)
145
+ self.assertGreaterEqual(1.0, np.amax(test_data.decoded_x))
146
+ self.assertLessEqual(-1.0, np.amin(test_data.decoded_x))
147
+
148
+ @parameterized.parameters(
149
+ itertools.product([tf.float32, tf.float64], [tf.float32, tf.float64],
150
+ [tf.float32, tf.float64]))
151
+ def test_input_types(self, x_dtype, clip_value_min_dtype,
152
+ clip_value_max_dtype):
153
+ # Tests combinations of input dtypes.
154
+ stage = clipping.ClipByValueEncodingStage(
155
+ tf.constant(-1.0, clip_value_min_dtype),
156
+ tf.constant(1.0, clip_value_max_dtype))
157
+ x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=x_dtype)
158
+ encode_params, decode_params = stage.get_params()
159
+ encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
160
+ decode_params)
161
+ test_data = test_utils.TestData(x, encoded_x, decoded_x)
162
+ test_data = self.evaluate_test_data(test_data)
163
+
164
+ self.common_asserts_for_test_data(test_data)
165
+ self.assertAllEqual([-2.0, -1.0, 0.0, 1.0, 2.0], test_data.x)
166
+ self.assertAllClose([-1.0, -1.0, 0.0, 1.0, 1.0], test_data.decoded_x)
167
+
168
+
169
+ if __name__ == '__main__':
170
+ tf.test.main()
cluster_preserve_integration_test.py ADDED
@@ -0,0 +1,709 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Integration tests for CQAT, PCQAT cases."""
16
+ from absl.testing import parameterized
17
+ import numpy as np
18
+ import tensorflow as tf
19
+
20
+ from tensorflow_model_optimization.python.core.clustering.keras import cluster
21
+ from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
22
+ from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as experimental_cluster
23
+ from tensorflow_model_optimization.python.core.keras.compat import keras
24
+ from tensorflow_model_optimization.python.core.quantization.keras import quantize
25
+ from tensorflow_model_optimization.python.core.quantization.keras.collab_opts.cluster_preserve import (
26
+ default_8bit_cluster_preserve_quantize_scheme,)
27
+ from tensorflow_model_optimization.python.core.quantization.keras.collab_opts.cluster_preserve.cluster_utils import (
28
+ strip_clustering_cqat,)
29
+
30
+
31
+ layers = keras.layers
32
+
33
+
34
+ class ClusterPreserveIntegrationTest(tf.test.TestCase, parameterized.TestCase):
35
+
36
+ def setUp(self):
37
+ super(ClusterPreserveIntegrationTest, self).setUp()
38
+ self.cluster_params = {
39
+ 'number_of_clusters': 4,
40
+ 'cluster_centroids_init': cluster_config.CentroidInitialization.LINEAR
41
+ }
42
+
43
+ def compile_and_fit(self, model):
44
+ """Here we compile and fit the model."""
45
+ model.compile(
46
+ loss=keras.losses.categorical_crossentropy,
47
+ optimizer='adam',
48
+ metrics=['accuracy'],
49
+ )
50
+ model.fit(
51
+ np.random.rand(20, 10),
52
+ keras.utils.to_categorical(np.random.randint(5, size=(20, 1)), 5),
53
+ batch_size=20,
54
+ )
55
+
56
+ def _get_number_of_unique_weights(self, stripped_model, layer_nr,
57
+ weight_name):
58
+ layer = stripped_model.layers[layer_nr]
59
+ if isinstance(layer, quantize.quantize_wrapper.QuantizeWrapper):
60
+ for weight_item in layer.trainable_weights:
61
+ if weight_name in weight_item.name:
62
+ weight = weight_item
63
+ else:
64
+ weight = getattr(layer, weight_name)
65
+ weights_as_list = weight.numpy().flatten()
66
+ nr_of_unique_weights = len(set(weights_as_list))
67
+ return nr_of_unique_weights
68
+
69
+ def _get_sparsity(self, model):
70
+ sparsity_list = []
71
+ for layer in model.layers:
72
+ for weights in layer.trainable_weights:
73
+ if 'kernel' in weights.name:
74
+ np_weights = keras.backend.get_value(weights)
75
+ sparsity = 1.0 - np.count_nonzero(np_weights) / float(
76
+ np_weights.size)
77
+ sparsity_list.append(sparsity)
78
+
79
+ return sparsity_list
80
+
81
+ def _get_clustered_model(self, preserve_sparsity):
82
+ """Cluster the (sparse) model and return clustered_model."""
83
+ tf.random.set_seed(1)
84
+ original_model = keras.Sequential([
85
+ layers.Dense(5, activation='softmax', input_shape=(10,)),
86
+ layers.Flatten(),
87
+ ])
88
+
89
+ # Manually set sparsity in the Dense layer if preserve_sparsity is on
90
+ if preserve_sparsity:
91
+ first_layer_weights = original_model.layers[0].get_weights()
92
+ first_layer_weights[0][:][0:2] = 0.0
93
+ original_model.layers[0].set_weights(first_layer_weights)
94
+
95
+ # Start the sparsity-aware clustering
96
+ clustering_params = {
97
+ 'number_of_clusters': 4,
98
+ 'cluster_centroids_init': cluster_config.CentroidInitialization.LINEAR,
99
+ 'preserve_sparsity': True
100
+ }
101
+
102
+ clustered_model = experimental_cluster.cluster_weights(
103
+ original_model, **clustering_params)
104
+
105
+ return clustered_model
106
+
107
+ def _get_conv_model(self,
108
+ nr_of_channels,
109
+ data_format=None,
110
+ kernel_size=(3, 3)):
111
+ """Returns functional model with Conv2D layer."""
112
+ inp = keras.layers.Input(shape=(32, 32), batch_size=100)
113
+ shape = (1, 32, 32) if data_format == 'channels_first' else (32, 32, 1)
114
+ x = keras.layers.Reshape(shape)(inp)
115
+ x = keras.layers.Conv2D(
116
+ filters=nr_of_channels,
117
+ kernel_size=kernel_size,
118
+ data_format=data_format,
119
+ activation='relu',
120
+ )(x)
121
+ x = keras.layers.MaxPool2D(2, 2)(x)
122
+ out = keras.layers.Flatten()(x)
123
+ model = keras.Model(inputs=inp, outputs=out)
124
+ return model
125
+
126
+ def _compile_and_fit_conv_model(self, model, nr_epochs=1):
127
+ """Compile and fit conv model from _get_conv_model."""
128
+ x_train = np.random.uniform(size=(500, 32, 32))
129
+ y_train = np.random.randint(low=0, high=1024, size=(500,))
130
+ model.compile(
131
+ optimizer=keras.optimizers.Adam(learning_rate=1e-4),
132
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
133
+ metrics=[keras.metrics.SparseCategoricalAccuracy(name='accuracy')],
134
+ )
135
+
136
+ model.fit(x_train, y_train, epochs=nr_epochs, batch_size=100, verbose=1)
137
+
138
+ return model
139
+
140
+ def _get_conv_clustered_model(self,
141
+ nr_of_channels,
142
+ nr_of_clusters,
143
+ data_format,
144
+ preserve_sparsity,
145
+ kernel_size=(3, 3)):
146
+ """Returns clustered per channel model with Conv2D layer."""
147
+ tf.random.set_seed(42)
148
+ model = self._get_conv_model(nr_of_channels, data_format, kernel_size)
149
+
150
+ if preserve_sparsity:
151
+ # Make the convolutional layer sparse by nullifying half of weights
152
+ assert model.layers[2].name == 'conv2d'
153
+
154
+ conv_layer_weights = model.layers[2].get_weights()
155
+ shape = conv_layer_weights[0].shape
156
+ conv_layer_weights_flatten = conv_layer_weights[0].flatten()
157
+
158
+ nr_elems = len(conv_layer_weights_flatten)
159
+ conv_layer_weights_flatten[0:1 + nr_elems // 2] = 0.0
160
+ pruned_conv_layer_weights = tf.reshape(conv_layer_weights_flatten, shape)
161
+ conv_layer_weights[0] = pruned_conv_layer_weights
162
+ model.layers[2].set_weights(conv_layer_weights)
163
+
164
+ clustering_params = {
165
+ 'number_of_clusters':
166
+ nr_of_clusters,
167
+ 'cluster_centroids_init':
168
+ cluster_config.CentroidInitialization.KMEANS_PLUS_PLUS,
169
+ 'cluster_per_channel':
170
+ True,
171
+ 'preserve_sparsity':
172
+ preserve_sparsity
173
+ }
174
+
175
+ clustered_model = experimental_cluster.cluster_weights(model,
176
+ **clustering_params)
177
+ clustered_model = self._compile_and_fit_conv_model(clustered_model)
178
+
179
+ # Returns un-stripped model
180
+ return clustered_model
181
+
182
+ def _pcqat_training(self, preserve_sparsity, quant_aware_annotate_model):
183
+ """PCQAT training on the input model."""
184
+ quant_aware_model = quantize.quantize_apply(
185
+ quant_aware_annotate_model,
186
+ scheme=default_8bit_cluster_preserve_quantize_scheme
187
+ .Default8BitClusterPreserveQuantizeScheme(preserve_sparsity))
188
+
189
+ self.compile_and_fit(quant_aware_model)
190
+
191
+ stripped_pcqat_model = strip_clustering_cqat(quant_aware_model)
192
+
193
+ # Check the unique weights of clustered_model and pcqat_model
194
+ # layer 0 is the quantize_layer
195
+ num_of_unique_weights_pcqat = self._get_number_of_unique_weights(
196
+ stripped_pcqat_model, 1, 'kernel')
197
+
198
+ sparsity_pcqat = self._get_sparsity(stripped_pcqat_model)
199
+
200
+ return sparsity_pcqat, num_of_unique_weights_pcqat
201
+
202
+ def testEndToEndClusterPreserve(self):
203
+ """Runs CQAT end to end and whole model is quantized."""
204
+ original_model = keras.Sequential(
205
+ [layers.Dense(5, activation='softmax', input_shape=(10,))]
206
+ )
207
+ clustered_model = cluster.cluster_weights(
208
+ original_model,
209
+ **self.cluster_params)
210
+ self.compile_and_fit(clustered_model)
211
+ clustered_model = cluster.strip_clustering(clustered_model)
212
+ num_of_unique_weights_clustering = self._get_number_of_unique_weights(
213
+ clustered_model, 0, 'kernel')
214
+
215
+ quant_aware_annotate_model = (
216
+ quantize.quantize_annotate_model(clustered_model))
217
+
218
+ quant_aware_model = quantize.quantize_apply(
219
+ quant_aware_annotate_model,
220
+ scheme=default_8bit_cluster_preserve_quantize_scheme
221
+ .Default8BitClusterPreserveQuantizeScheme())
222
+
223
+ self.compile_and_fit(quant_aware_model)
224
+ stripped_cqat_model = strip_clustering_cqat(quant_aware_model)
225
+
226
+ # Check the unique weights of a certain layer of
227
+ # clustered_model and pcqat_model
228
+ num_of_unique_weights_cqat = self._get_number_of_unique_weights(
229
+ stripped_cqat_model, 1, 'kernel')
230
+ self.assertAllEqual(num_of_unique_weights_clustering,
231
+ num_of_unique_weights_cqat)
232
+
233
+ def testEndToEndClusterPreservePerLayer(self):
234
+ """Runs CQAT end to end and model is quantized per layers."""
235
+ original_model = keras.Sequential([
236
+ layers.Dense(5, activation='relu', input_shape=(10,)),
237
+ layers.Dense(5, activation='softmax', input_shape=(10,)),
238
+ ])
239
+ clustered_model = cluster.cluster_weights(
240
+ original_model,
241
+ **self.cluster_params)
242
+ self.compile_and_fit(clustered_model)
243
+ clustered_model = cluster.strip_clustering(clustered_model)
244
+ num_of_unique_weights_clustering = self._get_number_of_unique_weights(
245
+ clustered_model, 1, 'kernel')
246
+
247
+ def apply_quantization_to_dense(layer):
248
+ if isinstance(layer, keras.layers.Dense):
249
+ return quantize.quantize_annotate_layer(layer)
250
+ return layer
251
+
252
+ quant_aware_annotate_model = keras.models.clone_model(
253
+ clustered_model,
254
+ clone_function=apply_quantization_to_dense,
255
+ )
256
+
257
+ quant_aware_model = quantize.quantize_apply(
258
+ quant_aware_annotate_model,
259
+ scheme=default_8bit_cluster_preserve_quantize_scheme
260
+ .Default8BitClusterPreserveQuantizeScheme())
261
+
262
+ self.compile_and_fit(quant_aware_model)
263
+ stripped_cqat_model = strip_clustering_cqat(
264
+ quant_aware_model)
265
+
266
+ # Check the unique weights of a certain layer of
267
+ # clustered_model and pcqat_model
268
+ num_of_unique_weights_cqat = self._get_number_of_unique_weights(
269
+ stripped_cqat_model, 2, 'kernel')
270
+ self.assertAllEqual(num_of_unique_weights_clustering,
271
+ num_of_unique_weights_cqat)
272
+
273
+ def testEndToEndClusterPreserveOneLayer(self):
274
+ """Runs CQAT end to end and model is quantized only for a single layer."""
275
+ original_model = keras.Sequential([
276
+ layers.Dense(5, activation='relu', input_shape=(10,)),
277
+ layers.Dense(5, activation='softmax', input_shape=(10,), name='qat'),
278
+ ])
279
+ clustered_model = cluster.cluster_weights(
280
+ original_model,
281
+ **self.cluster_params)
282
+ self.compile_and_fit(clustered_model)
283
+ clustered_model = cluster.strip_clustering(clustered_model)
284
+ num_of_unique_weights_clustering = self._get_number_of_unique_weights(
285
+ clustered_model, 1, 'kernel')
286
+
287
+ def apply_quantization_to_dense(layer):
288
+ if isinstance(layer, keras.layers.Dense):
289
+ if layer.name == 'qat':
290
+ return quantize.quantize_annotate_layer(layer)
291
+ return layer
292
+
293
+ quant_aware_annotate_model = keras.models.clone_model(
294
+ clustered_model,
295
+ clone_function=apply_quantization_to_dense,
296
+ )
297
+
298
+ quant_aware_model = quantize.quantize_apply(
299
+ quant_aware_annotate_model,
300
+ scheme=default_8bit_cluster_preserve_quantize_scheme
301
+ .Default8BitClusterPreserveQuantizeScheme())
302
+
303
+ self.compile_and_fit(quant_aware_model)
304
+
305
+ stripped_cqat_model = strip_clustering_cqat(
306
+ quant_aware_model)
307
+
308
+ # Check the unique weights of a certain layer of
309
+ # clustered_model and pcqat_model
310
+ num_of_unique_weights_cqat = self._get_number_of_unique_weights(
311
+ stripped_cqat_model, 1, 'kernel')
312
+ self.assertAllEqual(num_of_unique_weights_clustering,
313
+ num_of_unique_weights_cqat)
314
+
315
+ def testEndToEndPruneClusterPreserveQAT(self):
316
+ """Runs PCQAT end to end when we quantize the whole model."""
317
+ preserve_sparsity = True
318
+ clustered_model = self._get_clustered_model(preserve_sparsity)
319
+ # Save the kernel weights
320
+ first_layer_weights = clustered_model.layers[0].weights[1]
321
+ stripped_model_before_tuning = cluster.strip_clustering(
322
+ clustered_model)
323
+ nr_of_unique_weights_before = self._get_number_of_unique_weights(
324
+ stripped_model_before_tuning, 0, 'kernel')
325
+
326
+ self.compile_and_fit(clustered_model)
327
+
328
+ stripped_model_clustered = cluster.strip_clustering(clustered_model)
329
+ weights_after_tuning = stripped_model_clustered.layers[0].kernel
330
+ nr_of_unique_weights_after = self._get_number_of_unique_weights(
331
+ stripped_model_clustered, 0, 'kernel')
332
+
333
+ # Check after sparsity-aware clustering, despite zero centroid can drift,
334
+ # the final number of unique weights remains the same
335
+ self.assertEqual(nr_of_unique_weights_before, nr_of_unique_weights_after)
336
+
337
+ # Check that the zero weights stayed the same before and after tuning.
338
+ # There might be new weights that become zeros but sparsity-aware
339
+ # clustering preserves the original zero weights in the original positions
340
+ # of the weight array
341
+ self.assertTrue(
342
+ np.array_equal(first_layer_weights[:][0:2],
343
+ weights_after_tuning[:][0:2]))
344
+
345
+ # Check sparsity before the input of PCQAT
346
+ sparsity_pruning = self._get_sparsity(stripped_model_clustered)
347
+
348
+ # PCQAT: when the preserve_sparsity flag is True, the PCQAT should work
349
+ quant_aware_annotate_model = (
350
+ quantize.quantize_annotate_model(stripped_model_clustered)
351
+ )
352
+
353
+ # When preserve_sparsity is True in PCQAT, the final sparsity of
354
+ # the layer stays the same or larger than that of the input layer
355
+ preserve_sparsity = True
356
+ sparsity_pcqat, unique_weights_pcqat = self._pcqat_training(
357
+ preserve_sparsity, quant_aware_annotate_model)
358
+ self.assertAllGreaterEqual(np.array(sparsity_pcqat),
359
+ sparsity_pruning[0])
360
+ self.assertAllEqual(nr_of_unique_weights_after, unique_weights_pcqat)
361
+
362
+ def testEndToEndClusterPreserveQATClusteredPerChannel(
363
+ self, data_format='channels_last'):
364
+ """Runs CQAT end to end for the model that is clustered per channel."""
365
+
366
+ nr_of_channels = 12
367
+ nr_of_clusters = 4
368
+
369
+ clustered_model = self._get_conv_clustered_model(
370
+ nr_of_channels, nr_of_clusters, data_format, preserve_sparsity=False)
371
+ stripped_model = cluster.strip_clustering(clustered_model)
372
+
373
+ # Save the kernel weights
374
+ conv2d_layer = stripped_model.layers[2]
375
+ self.assertEqual(conv2d_layer.name, 'conv2d')
376
+
377
+ # should be nr_of_channels * nr_of_clusters
378
+ nr_unique_weights = -1
379
+
380
+ for weight in conv2d_layer.weights:
381
+ if 'kernel' in weight.name:
382
+ nr_unique_weights = len(np.unique(weight.numpy()))
383
+ self.assertLessEqual(nr_unique_weights, nr_of_clusters*nr_of_channels)
384
+
385
+ quant_aware_annotate_model = (
386
+ quantize.quantize_annotate_model(stripped_model)
387
+ )
388
+
389
+ quant_aware_model = quantize.quantize_apply(
390
+ quant_aware_annotate_model,
391
+ scheme=default_8bit_cluster_preserve_quantize_scheme
392
+ .Default8BitClusterPreserveQuantizeScheme())
393
+
394
+ # Lets train for more epochs to have a chance to scatter clusters
395
+ model = self._compile_and_fit_conv_model(quant_aware_model, 3)
396
+
397
+ stripped_cqat_model = strip_clustering_cqat(model)
398
+
399
+ # Check the unique weights of a certain layer of
400
+ # clustered_model and pcqat_model
401
+ layer_nr = 3
402
+ num_of_unique_weights_cqat = self._get_number_of_unique_weights(
403
+ stripped_cqat_model, layer_nr, 'kernel')
404
+ self.assertLessEqual(num_of_unique_weights_cqat, nr_unique_weights)
405
+
406
+ # We need to do tighter check: we check that the number of unique
407
+ # weights per channel is less than the given nr_of_channels
408
+ layer = stripped_cqat_model.layers[layer_nr]
409
+ weight_to_check = None
410
+ if isinstance(layer, quantize.quantize_wrapper.QuantizeWrapper):
411
+ for weight_item in layer.trainable_weights:
412
+ if 'kernel' in weight_item.name:
413
+ weight_to_check = weight_item
414
+
415
+ assert weight_to_check is not None
416
+
417
+ for i in range(nr_of_channels):
418
+ nr_unique_weights_per_channel = len(
419
+ np.unique(weight_to_check[:, :, :, i]))
420
+ assert nr_unique_weights_per_channel == nr_of_clusters
421
+
422
+ def testEndToEndPCQATClusteredPerChannel(self, data_format='channels_last'):
423
+ """Runs PCQAT end to end for the model that is clustered per channel."""
424
+
425
+ nr_of_channels = 12
426
+ nr_of_clusters = 4
427
+
428
+ clustered_model = self._get_conv_clustered_model(
429
+ nr_of_channels, nr_of_clusters, data_format, preserve_sparsity=True)
430
+ stripped_model = cluster.strip_clustering(clustered_model)
431
+
432
+ # Save the kernel weights
433
+ conv2d_layer = stripped_model.layers[2]
434
+ self.assertEqual(conv2d_layer.name, 'conv2d')
435
+
436
+ # should be nr_of_channels * nr_of_clusters
437
+ nr_unique_weights = -1
438
+
439
+ for weight in conv2d_layer.weights:
440
+ if 'kernel' in weight.name:
441
+ nr_unique_weights = len(np.unique(weight.numpy()))
442
+ self.assertLessEqual(nr_unique_weights, nr_of_clusters*nr_of_channels)
443
+
444
+ # get sparsity before PCQAT training
445
+ # we expect that only one value will be returned
446
+ control_sparsity = self._get_sparsity(stripped_model)
447
+ self.assertGreater(control_sparsity[0], 0.5)
448
+
449
+ quant_aware_annotate_model = (
450
+ quantize.quantize_annotate_model(stripped_model)
451
+ )
452
+
453
+ quant_aware_model = quantize.quantize_apply(
454
+ quant_aware_annotate_model,
455
+ scheme=default_8bit_cluster_preserve_quantize_scheme
456
+ .Default8BitClusterPreserveQuantizeScheme())
457
+
458
+ # Lets train for more epochs to have a chance to scatter clusters
459
+ model = self._compile_and_fit_conv_model(quant_aware_model, 3)
460
+
461
+ stripped_cqat_model = strip_clustering_cqat(model)
462
+
463
+ # Check the unique weights of a certain layer of
464
+ # clustered_model and cqat_model
465
+ layer_nr = 3
466
+ num_of_unique_weights_cqat = self._get_number_of_unique_weights(
467
+ stripped_cqat_model, layer_nr, 'kernel')
468
+ self.assertLessEqual(num_of_unique_weights_cqat, nr_unique_weights)
469
+
470
+ # We need to do tighter check: we check that the number of unique
471
+ # weights per channel is less than the given nr_of_channels
472
+ layer = stripped_cqat_model.layers[layer_nr]
473
+ weight_to_check = None
474
+ if isinstance(layer, quantize.quantize_wrapper.QuantizeWrapper):
475
+ for weight_item in layer.trainable_weights:
476
+ if 'kernel' in weight_item.name:
477
+ weight_to_check = weight_item
478
+
479
+ assert weight_to_check is not None
480
+
481
+ for i in range(nr_of_channels):
482
+ nr_unique_weights_per_channel = len(
483
+ np.unique(weight_to_check[:, :, :, i]))
484
+ assert nr_unique_weights_per_channel == nr_of_clusters
485
+
486
+ cqat_sparsity = self._get_sparsity(stripped_cqat_model)
487
+ self.assertLessEqual(cqat_sparsity[0], control_sparsity[0])
488
+
489
+ def testEndToEndPCQATClusteredPerChannelConv2d1x1(self,
490
+ data_format='channels_last'
491
+ ):
492
+ """Runs PCQAT for model containing a 1x1 Conv2D.
493
+
494
+ (with insufficient number of weights per channel).
495
+
496
+ Args:
497
+ data_format: Format of input data.
498
+ """
499
+ nr_of_channels = 12
500
+ nr_of_clusters = 4
501
+
502
+ # Ensure a warning is given to the user that
503
+ # clustering is not implemented for this layer
504
+ with self.assertWarnsRegex(Warning,
505
+ r'Layer conv2d does not have enough weights'):
506
+ clustered_model = self._get_conv_clustered_model(
507
+ nr_of_channels,
508
+ nr_of_clusters,
509
+ data_format,
510
+ preserve_sparsity=True,
511
+ kernel_size=(1, 1))
512
+ stripped_model = cluster.strip_clustering(clustered_model)
513
+
514
+ # Save the kernel weights
515
+ conv2d_layer = stripped_model.layers[2]
516
+ self.assertEqual(conv2d_layer.name, 'conv2d')
517
+
518
+ for weight in conv2d_layer.weights:
519
+ if 'kernel' in weight.name:
520
+ # Original number of unique weights
521
+ nr_original_weights = len(np.unique(weight.numpy()))
522
+ self.assertLess(nr_original_weights, nr_of_channels * nr_of_clusters)
523
+
524
+ # Demonstrate unmodified test layer has less weights
525
+ # than requested clusters
526
+ for channel in range(nr_of_channels):
527
+ channel_weights = (
528
+ weight[:, channel, :, :]
529
+ if data_format == 'channels_first' else weight[:, :, :, channel])
530
+ nr_channel_weights = len(channel_weights)
531
+ self.assertGreater(nr_channel_weights, 0)
532
+ self.assertLessEqual(nr_channel_weights, nr_of_clusters)
533
+
534
+ # get sparsity before PCQAT training
535
+ # we expect that only one value will be returned
536
+ control_sparsity = self._get_sparsity(stripped_model)
537
+ self.assertGreater(control_sparsity[0], 0.5)
538
+
539
+ quant_aware_annotate_model = (
540
+ quantize.quantize_annotate_model(stripped_model))
541
+
542
+ with self.assertWarnsRegex(
543
+ Warning, r'No clustering performed on layer quant_conv2d'):
544
+ quant_aware_model = quantize.quantize_apply(
545
+ quant_aware_annotate_model,
546
+ scheme=default_8bit_cluster_preserve_quantize_scheme
547
+ .Default8BitClusterPreserveQuantizeScheme(preserve_sparsity=True))
548
+
549
+ # Lets train for more epochs to have a chance to scatter clusters
550
+ model = self._compile_and_fit_conv_model(quant_aware_model, 3)
551
+
552
+ stripped_cqat_model = strip_clustering_cqat(model)
553
+
554
+ # Check the unique weights of a certain layer of
555
+ # clustered_model and cqat_model, ensuring unchanged
556
+ layer_nr = 3
557
+ num_of_unique_weights_cqat = self._get_number_of_unique_weights(
558
+ stripped_cqat_model, layer_nr, 'kernel')
559
+ self.assertEqual(num_of_unique_weights_cqat, nr_original_weights)
560
+
561
+ cqat_sparsity = self._get_sparsity(stripped_cqat_model)
562
+ self.assertLessEqual(cqat_sparsity[0], control_sparsity[0])
563
+
564
+ def testPassingNonPrunedModelToPCQAT(self):
565
+ """Runs PCQAT as CQAT if the input model is not pruned."""
566
+ preserve_sparsity = False
567
+ clustered_model = self._get_clustered_model(preserve_sparsity)
568
+
569
+ clustered_model = cluster.strip_clustering(clustered_model)
570
+ nr_of_unique_weights_after = self._get_number_of_unique_weights(
571
+ clustered_model, 0, 'kernel')
572
+
573
+ # Check after plain clustering, if there are no zero weights,
574
+ # PCQAT falls back to CQAT
575
+ quant_aware_annotate_model = (
576
+ quantize.quantize_annotate_model(clustered_model)
577
+ )
578
+
579
+ quant_aware_model = quantize.quantize_apply(
580
+ quant_aware_annotate_model,
581
+ scheme=default_8bit_cluster_preserve_quantize_scheme
582
+ .Default8BitClusterPreserveQuantizeScheme(True))
583
+
584
+ self.compile_and_fit(quant_aware_model)
585
+ stripped_pcqat_model = strip_clustering_cqat(
586
+ quant_aware_model)
587
+
588
+ # Check the unique weights of clustered_model and pcqat_model
589
+ num_of_unique_weights_pcqat = self._get_number_of_unique_weights(
590
+ stripped_pcqat_model, 1, 'kernel')
591
+ self.assertAllEqual(nr_of_unique_weights_after,
592
+ num_of_unique_weights_pcqat)
593
+
594
+ @parameterized.parameters((0.), (2.))
595
+ def testPassingModelWithUniformWeightsToPCQAT(self, uniform_weights):
596
+ """If pruned_clustered_model has uniform weights, it won't break PCQAT."""
597
+ preserve_sparsity = True
598
+ original_model = keras.Sequential([
599
+ layers.Dense(5, activation='softmax', input_shape=(10,)),
600
+ layers.Flatten(),
601
+ ])
602
+
603
+ # Manually set all weights to the same value in the Dense layer
604
+ first_layer_weights = original_model.layers[0].get_weights()
605
+ first_layer_weights[0][:] = uniform_weights
606
+ original_model.layers[0].set_weights(first_layer_weights)
607
+
608
+ # Start the sparsity-aware clustering
609
+ clustering_params = {
610
+ 'number_of_clusters': 4,
611
+ 'cluster_centroids_init': cluster_config.CentroidInitialization.LINEAR,
612
+ 'preserve_sparsity': True
613
+ }
614
+
615
+ clustered_model = experimental_cluster.cluster_weights(
616
+ original_model, **clustering_params)
617
+ clustered_model = cluster.strip_clustering(clustered_model)
618
+
619
+ nr_of_unique_weights_after = self._get_number_of_unique_weights(
620
+ clustered_model, 0, 'kernel')
621
+ sparsity_pruning = self._get_sparsity(clustered_model)
622
+
623
+ quant_aware_annotate_model = (
624
+ quantize.quantize_annotate_model(clustered_model)
625
+ )
626
+
627
+ sparsity_pcqat, unique_weights_pcqat = self._pcqat_training(
628
+ preserve_sparsity, quant_aware_annotate_model)
629
+ self.assertAllGreaterEqual(np.array(sparsity_pcqat),
630
+ sparsity_pruning[0])
631
+ self.assertAllEqual(nr_of_unique_weights_after, unique_weights_pcqat)
632
+
633
+ def testTrainableWeightsBehaveCorrectlyDuringPCQAT(self):
634
+ """PCQAT zero centroid masks stay the same and trainable variables are updating between epochs."""
635
+ preserve_sparsity = True
636
+ clustered_model = self._get_clustered_model(preserve_sparsity)
637
+ clustered_model = cluster.strip_clustering(clustered_model)
638
+
639
+ # Apply PCQAT
640
+ quant_aware_annotate_model = (
641
+ quantize.quantize_annotate_model(clustered_model)
642
+ )
643
+
644
+ quant_aware_model = quantize.quantize_apply(
645
+ quant_aware_annotate_model,
646
+ scheme=default_8bit_cluster_preserve_quantize_scheme
647
+ .Default8BitClusterPreserveQuantizeScheme(True))
648
+
649
+ quant_aware_model.compile(
650
+ loss=keras.losses.categorical_crossentropy,
651
+ optimizer='adam',
652
+ metrics=['accuracy'],
653
+ )
654
+
655
+ class CheckCentroidsAndTrainableVarsCallback(keras.callbacks.Callback):
656
+ """Check the updates of trainable variables and centroid masks."""
657
+
658
+ def on_epoch_begin(self, batch, logs=None):
659
+ # Check cluster centroids have the zero in the right position
660
+ vars_dictionary = self.model.layers[1]._weight_vars[0][2]
661
+ self.centroid_mask = vars_dictionary['centroids_mask']
662
+ self.zero_centroid_index_begin = np.where(
663
+ self.centroid_mask == 0)[0]
664
+
665
+ # Check trainable weights before training
666
+ self.layer_kernel = (
667
+ self.model.layers[1].weights[3].numpy()
668
+ )
669
+ self.original_weight = vars_dictionary['ori_weights_vars_tf'].numpy()
670
+ self.centroids = vars_dictionary['cluster_centroids_tf'].numpy()
671
+
672
+ def on_epoch_end(self, batch, logs=None):
673
+ # Check the index of the zero centroids are not changed after training
674
+ vars_dictionary = self.model.layers[1]._weight_vars[0][2]
675
+ self.zero_centroid_index_end = np.where(
676
+ vars_dictionary['centroids_mask'] == 0)[0]
677
+ assert np.array_equal(
678
+ self.zero_centroid_index_begin,
679
+ self.zero_centroid_index_end
680
+ )
681
+
682
+ # Check trainable variables after training are updated
683
+ assert not np.array_equal(
684
+ self.layer_kernel,
685
+ self.model.layers[1].weights[3].numpy()
686
+ )
687
+ assert not np.array_equal(
688
+ self.original_weight,
689
+ vars_dictionary['ori_weights_vars_tf'].numpy()
690
+ )
691
+ assert not np.array_equal(
692
+ self.centroids,
693
+ vars_dictionary['cluster_centroids_tf'].numpy()
694
+ )
695
+
696
+ # Use many epochs to verify layer's kernel weights are updating because
697
+ # they can stay the same after being trained using only the first batch
698
+ # of data for instance
699
+ quant_aware_model.fit(
700
+ np.random.rand(20, 10),
701
+ keras.utils.to_categorical(np.random.randint(5, size=(20, 1)), 5),
702
+ steps_per_epoch=5,
703
+ epochs=3,
704
+ callbacks=[CheckCentroidsAndTrainableVarsCallback()],
705
+ )
706
+
707
+
708
+ if __name__ == '__main__':
709
+ tf.test.main()
cluster_preserve_quantize_registry.py ADDED
@@ -0,0 +1,539 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Registry responsible for built-in keras classes."""
16
+
17
+ import logging
18
+ import warnings
19
+
20
+ import tensorflow as tf
21
+
22
+ from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
23
+ from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry
24
+ from tensorflow_model_optimization.python.core.keras.compat import keras
25
+ from tensorflow_model_optimization.python.core.quantization.keras import quant_ops
26
+ from tensorflow_model_optimization.python.core.quantization.keras import quantizers
27
+ from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
28
+ from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantizers
29
+
30
+
31
+ layers = keras.layers
32
+ K = keras.backend
33
+
34
+ CLUSTER_CENTROIDS = 'cluster_centroids_tf'
35
+ PULLING_INDICES = 'pulling_indices_tf'
36
+ ORIGINAL_WEIGHTS = 'ori_weights_vars_tf'
37
+ WEIGHT_NAME = 'weight_name'
38
+ CLUSTERING_IMPL = 'clst_impl'
39
+ CENTROIDS_MASK = 'centroids_mask'
40
+ SPARSITY_MASK = 'sparsity_mask'
41
+
42
+
43
+ def get_unique(t):
44
+ """Get unique values and lookup index from N-D tensor.
45
+
46
+ Args:
47
+ t: tensor
48
+ Returns:
49
+ centroids (unique values), lookup index (same shape as input tensor)
50
+ Example:
51
+ t:
52
+ ([[1.0, 2.0],
53
+ [2.0, 3.0],
54
+ [3.0, 3.0],
55
+ [1.0, 2.0]]
56
+ )
57
+ centroids(unique values):
58
+ ([1.0, 2.0, 3.0])
59
+ output final index:
60
+ ([[0, 1],
61
+ [1, 2],
62
+ [2, 2],
63
+ [0, 1]]
64
+ )
65
+ """
66
+ t_flatten = tf.reshape(t, shape=(-1,))
67
+ uniques, index = tf.unique(t_flatten)
68
+ return uniques, tf.reshape(index, shape=tf.shape(t))
69
+
70
+
71
+ def get_centroids(layer, weight, data_format):
72
+ """Gets centroid infos from the weights of a layer.
73
+
74
+ Args:
75
+ layer: The Keras layer from which the weight belong.
76
+ weight: The weight tensor to get the centroids info from.
77
+ data_format: string to indicate format: "channels_first" or "channels_last".
78
+ Returns:
79
+ A 4-tuple of centroids (unique values), number of centroids, lookup index,
80
+ whether to cluster per channel (boolean).
81
+ """
82
+ cluster_per_channel = layer.layer and isinstance(
83
+ layer.layer, keras.layers.Conv2D
84
+ )
85
+
86
+ if not cluster_per_channel:
87
+ centroids, index = get_unique(weight)
88
+ return centroids, tf.size(centroids), index, False
89
+
90
+ # In case of cluster_per_channel we need to extract
91
+ # unique values (centroids) for each channel.
92
+ num_channels = weight.shape[1 if data_format == 'channels_first' else -1]
93
+ channel_centroids = []
94
+ channel_indices = []
95
+ num_centroids = []
96
+
97
+ for channel in range(num_channels):
98
+ channel_weights = weight[:, :, :, channel]
99
+ centroids, indices = get_unique(channel_weights)
100
+
101
+ channel_centroids.append(centroids)
102
+ channel_indices.append(indices)
103
+ num_centroids.append(tf.size(centroids))
104
+
105
+ max_centroid = max(num_centroids)
106
+ max_diff = max_centroid - min(num_centroids)
107
+
108
+ if max_diff > 1:
109
+ centroids, index = get_unique(weight)
110
+ return centroids, tf.size(centroids), index, False
111
+
112
+ for i, centroid in enumerate(channel_centroids):
113
+ if num_centroids[i] != max_centroid:
114
+ one_padding = tf.ones([max_centroid - num_centroids[i]])
115
+ channel_centroids[i] = tf.concat([centroid, one_padding], 0)
116
+
117
+ centroids = tf.convert_to_tensor(channel_centroids)
118
+ lookup = tf.convert_to_tensor(channel_indices)
119
+
120
+ lookup = tf.transpose(
121
+ lookup,
122
+ perm=(1, 0, 2, 3) if data_format == 'channels_first' else (1, 2, 3, 0))
123
+
124
+ return centroids, max_centroid, lookup, True
125
+
126
+
127
+ class _ClusterPreserveInfo(object):
128
+ """ClusterPreserveInfo."""
129
+
130
+ def __init__(self, weight_attrs, quantize_config_attrs):
131
+ """ClusterPreserveInfo.
132
+
133
+ Args:
134
+ weight_attrs: list of cluster preservable weight attributes of layer.
135
+ quantize_config_attrs: list of quantization configuration class name.
136
+ """
137
+ self.weight_attrs = weight_attrs
138
+ self.quantize_config_attrs = quantize_config_attrs
139
+
140
+
141
+ class ClusterPreserveQuantizeRegistry(object):
142
+ """ClusterPreserveQuantizeRegistry is for built-in keras layers."""
143
+ # The keys represent built-in keras layers; the first values represent the
144
+ # the variables within the layers which hold the kernel weights, second
145
+ # values represent the class name of quantization configuration for layers.
146
+ # This decide the weights of layers with quantization configurations are
147
+ # cluster preservable.
148
+ _LAYERS_CONFIG_MAP = {
149
+ layers.Conv2D:
150
+ _ClusterPreserveInfo(['kernel'], ['Default8BitConvQuantizeConfig']),
151
+ layers.Dense:
152
+ _ClusterPreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),
153
+
154
+ # DepthwiseConv2D is supported with 8bit qat, but not with
155
+ # clustering, thus for DepthwiseConv2D CQAT,
156
+ # preserving clustered weights is disabled.
157
+ layers.DepthwiseConv2D:
158
+ _ClusterPreserveInfo(['depthwise_kernel'],
159
+ ['Default8BitQuantizeConfig']),
160
+
161
+ # layers that are supported with clustering, but not yet with qat
162
+ # layers.Conv1D:
163
+ # _ClusterPreserveInfo(['kernel'], []),
164
+ # layers.Conv2DTranspose:
165
+ # _ClusterPreserveInfo(['kernel'], []),
166
+ # layers.Conv3D:
167
+ # _ClusterPreserveInfo(['kernel'], []),
168
+ # layers.Conv3DTranspose:
169
+ # _ClusterPreserveInfo(['kernel'], []),
170
+ # layers.LocallyConnected1D:
171
+ # _ClusterPreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),
172
+ # layers.LocallyConnected2D:
173
+ # _ClusterPreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),
174
+
175
+ # SeparableConv need verify from 8bit qat
176
+ # layers.SeparableConv1D:
177
+ # _ClusterPreserveInfo(['pointwise_kernel'],
178
+ # ['Default8BitConvQuantizeConfig']),
179
+ # layers.SeparableConv2D:
180
+ # _ClusterPreserveInfo(['pointwise_kernel'],
181
+ # ['Default8BitConvQuantizeConfig']),
182
+
183
+ # Embedding need verify from 8bit qat
184
+ # layers.Embedding: _ClusterPreserveInfo(['embeddings'], []),
185
+ }
186
+
187
+ _DISABLE_CLUSTER_PRESERVE = frozenset({
188
+ layers.DepthwiseConv2D,
189
+ })
190
+
191
+ def __init__(self, preserve_sparsity):
192
+ self._config_quantizer_map = {
193
+ 'Default8BitQuantizeConfig':
194
+ ClusterPreserveDefault8BitWeightsQuantizer(preserve_sparsity),
195
+ 'Default8BitConvQuantizeConfig':
196
+ ClusterPreserveDefault8BitConvWeightsQuantizer(preserve_sparsity),
197
+ }
198
+
199
+ @classmethod
200
+ def _no_trainable_weights(cls, layer):
201
+ """Returns whether this layer has trainable weights.
202
+
203
+ Args:
204
+ layer: The layer to check for trainable weights.
205
+ Returns:
206
+ True/False whether the layer has trainable weights.
207
+ """
208
+ return not layer.trainable_weights
209
+
210
+ @classmethod
211
+ def _disable_cluster_preserve(cls, layer):
212
+ """Returns whether to disable this layer for preserving clusters.
213
+
214
+ Args:
215
+ layer: The layer to check for disabling.
216
+ Returns:
217
+ True/False whether disabling this layer for preserving clusters.
218
+ """
219
+ return layer.__class__ in cls._DISABLE_CLUSTER_PRESERVE
220
+
221
+ @classmethod
222
+ def supports(cls, layer):
223
+ """Returns whether the registry supports this layer type.
224
+
225
+ Args:
226
+ layer: The layer to check for support.
227
+ Returns:
228
+ True/False whether the layer type is supported.
229
+ """
230
+ # layers without trainable weights are consider supported,
231
+ # e.g., ReLU, Softmax, and AveragePooling2D.
232
+ if cls._no_trainable_weights(layer):
233
+ return True
234
+
235
+ if layer.__class__ in cls._LAYERS_CONFIG_MAP:
236
+ return True
237
+
238
+ return False
239
+
240
+ @classmethod
241
+ def _weight_names(cls, layer):
242
+
243
+ if cls._no_trainable_weights(layer):
244
+ return []
245
+
246
+ return cls._LAYERS_CONFIG_MAP[layer.__class__].weight_attrs
247
+
248
+ def apply_cluster_preserve_quantize_config(self, layer, quantize_config):
249
+ """Applies cluster-preserve weight quantizer.
250
+
251
+ Args:
252
+ layer: The layer to check for support.
253
+ quantize_config: quantization config for supporting cluster preservation
254
+ on clustered weights
255
+ Returns:
256
+ The quantize_config with addon cluster preserve weight_quantizer.
257
+ """
258
+ if not self.supports(layer):
259
+ raise ValueError('Layer ' + str(layer.__class__) + ' is not supported.')
260
+
261
+ # Example: ReLU, Softmax, and AveragePooling2D (without trainable weights)
262
+ # DepthwiseConv2D (cluster_preserve is disabled)
263
+ if self._no_trainable_weights(layer) or self._disable_cluster_preserve(
264
+ layer):
265
+ return quantize_config
266
+
267
+ # Example: Conv2D, Dense layers
268
+ if quantize_config.__class__.__name__ in self._LAYERS_CONFIG_MAP[
269
+ layer.__class__].quantize_config_attrs:
270
+ quantize_config.weight_quantizer = self._config_quantizer_map[
271
+ quantize_config.__class__.__name__]
272
+ else:
273
+ raise ValueError('Configuration ' +
274
+ str(quantize_config.__class__.__name__) +
275
+ ' is not supported for Layer ' + str(layer.__class__) +
276
+ '.')
277
+
278
+ return quantize_config
279
+
280
+
281
+ class Default8bitClusterPreserveQuantizeRegistry(
282
+ ClusterPreserveQuantizeRegistry):
283
+ """Default 8 bit ClusterPreserveQuantizeRegistry."""
284
+
285
+ def get_quantize_config(self, layer):
286
+ """Returns the quantization config with weight_quantizer for a given layer.
287
+
288
+ Args:
289
+ layer: input layer to return quantize config for.
290
+ Returns:
291
+ Returns the quantization config for cluster preserve weight_quantizer.
292
+ """
293
+ quantize_config = (default_8bit_quantize_registry.
294
+ Default8BitQuantizeRegistry().
295
+ get_quantize_config(layer))
296
+ cluster_aware_quantize_config = super(
297
+ Default8bitClusterPreserveQuantizeRegistry,
298
+ self).apply_cluster_preserve_quantize_config(layer, quantize_config)
299
+
300
+ return cluster_aware_quantize_config
301
+
302
+
303
+ class ClusterPreserveDefaultWeightsQuantizer(quantizers.LastValueQuantizer):
304
+ """Quantize weights while preserving clusters."""
305
+
306
+ def __init__(
307
+ self, num_bits, per_axis, symmetric, narrow_range, preserve_sparsity):
308
+ """ClusterPreserveDefaultWeightsQuantizer.
309
+
310
+ Args:
311
+ num_bits: Number of bits for quantization
312
+ per_axis: Whether to apply per_axis quantization. The last dimension is
313
+ used as the axis.
314
+ symmetric: If true, use symmetric quantization limits instead of training
315
+ the minimum and maximum of each quantization range separately.
316
+ narrow_range: In case of 8 bits, narrow_range nudges the quantized range
317
+ to be [-127, 127] instead of [-128, 127]. This ensures symmetric
318
+ range has 0 as the centre.
319
+ preserve_sparsity: Whether to apply prune-cluster-preserving quantization
320
+ aware training.
321
+ """
322
+ super(ClusterPreserveDefaultWeightsQuantizer, self).__init__(
323
+ num_bits=num_bits,
324
+ per_axis=per_axis,
325
+ symmetric=symmetric,
326
+ narrow_range=narrow_range,
327
+ )
328
+ self.preserve_sparsity = preserve_sparsity
329
+
330
+ def _build_clusters(self, name, layer):
331
+ """Extracts the cluster centroids and cluster indices.
332
+
333
+ Extracts cluster centroids and cluster indices from the pretrained
334
+ clustered model when the input layer is clustered.
335
+
336
+ Args:
337
+ name: Name of weights in layer.
338
+ layer: Quantization wrapped keras layer.
339
+ Returns:
340
+ A dictionary of the initial values of the
341
+ cluster centroids, cluster indices, original weights,
342
+ the pretrained flag for marking the first training
343
+ epoch, and weight name.
344
+ """
345
+ result = {}
346
+ weights = getattr(layer.layer, name)
347
+ if self.preserve_sparsity and not tf.reduce_any(weights == 0):
348
+ self.preserve_sparsity = False
349
+ logging.warning(
350
+ 'Input layer does not contain zero weights, so apply CQAT instead.')
351
+ centroids_mask = None
352
+
353
+ # Detects whether layer is convolutional and is clustered per channel
354
+ data_format = getattr(layer.layer, 'data_format', None)
355
+ centroids, num_centroids, lookup, cluster_per_channel = get_centroids(
356
+ layer, weights, data_format)
357
+
358
+ if self.preserve_sparsity:
359
+ sparsity_mask = tf.math.divide_no_nan(weights, weights)
360
+ zero_idx = tf.argmin(tf.abs(centroids), axis=-1)
361
+ centroids_mask = 1.0 - tf.one_hot(zero_idx, num_centroids)
362
+ result = {SPARSITY_MASK: sparsity_mask}
363
+
364
+ # Prepare clustering variables for the Keras graph when clusters
365
+ # exist, assuming we do not use number_of_clusters larger than 1024
366
+ if num_centroids > 1024:
367
+ warnings.warn(f'No clustering performed on layer {layer.name}.\n'
368
+ f'Too many centroids to cluster.')
369
+ return result
370
+ # If not enough clusters, we do not preserve clustering
371
+ elif num_centroids <= 1:
372
+ warnings.warn(f'No clustering performed on layer {layer.name}.\n'
373
+ f'Perhaps too many clusters requested for this layer?')
374
+ return result
375
+ else:
376
+ clst_centroids_tf = layer.add_weight(
377
+ CLUSTER_CENTROIDS,
378
+ shape=centroids.shape,
379
+ initializer=keras.initializers.Constant(
380
+ value=K.batch_get_value([centroids])[0]
381
+ ),
382
+ dtype=centroids.dtype,
383
+ trainable=True,
384
+ )
385
+
386
+ ori_weights_tf = layer.add_weight(
387
+ ORIGINAL_WEIGHTS,
388
+ shape=weights.shape,
389
+ initializer=keras.initializers.Constant(
390
+ value=K.batch_get_value([weights])[0]
391
+ ),
392
+ dtype=weights.dtype,
393
+ trainable=True,
394
+ )
395
+
396
+ # Get clustering implementation according to layer type
397
+ clustering_impl_cls = clustering_registry.ClusteringLookupRegistry(
398
+ ).get_clustering_impl(
399
+ layer.layer, name, cluster_per_channel=cluster_per_channel)
400
+ clustering_impl = clustering_impl_cls(
401
+ clst_centroids_tf, cluster_config.GradientAggregation.SUM,
402
+ data_format)
403
+
404
+ pulling_indices = tf.dtypes.cast(
405
+ clustering_impl.get_pulling_indices(ori_weights_tf),
406
+ lookup.dtype
407
+ )
408
+
409
+ pulling_indices_tf = layer.add_weight(
410
+ PULLING_INDICES,
411
+ shape=lookup.shape,
412
+ initializer=keras.initializers.Constant(
413
+ value=K.batch_get_value([pulling_indices])[0]
414
+ ),
415
+ dtype=lookup.dtype,
416
+ trainable=False,
417
+ )
418
+
419
+ result_clst = {
420
+ CLUSTER_CENTROIDS: clst_centroids_tf,
421
+ PULLING_INDICES: pulling_indices_tf,
422
+ ORIGINAL_WEIGHTS: ori_weights_tf,
423
+ WEIGHT_NAME: name,
424
+ CLUSTERING_IMPL: clustering_impl,
425
+ CENTROIDS_MASK: centroids_mask,
426
+ }
427
+ result.update(result_clst)
428
+ return result
429
+
430
+ def build(self, tensor_shape, name, layer):
431
+ """Build (P)CQAT wrapper.
432
+
433
+ When preserve_sparsity is true and the input is clustered.
434
+
435
+ Args:
436
+ tensor_shape: Shape of weights which needs to be quantized.
437
+ name: Name of weights in layer.
438
+ layer: Quantization wrapped keras layer.
439
+ Returns:
440
+ Dictionary of centroids, indices and
441
+ quantization params, the dictionary will be passed
442
+ to __call__ function.
443
+ """
444
+ # To get all the initial values from pretrained clustered model
445
+ result = self._build_clusters(name, layer)
446
+ # Result can have clustering nodes, then this is CQAT
447
+ # Result can have both clustering nodes and sparsity mask, then
448
+ # this will be PCQAT
449
+ result.update(
450
+ super(ClusterPreserveDefaultWeightsQuantizer,
451
+ self).build(tensor_shape, name, layer))
452
+
453
+ return result
454
+
455
+ def __call__(self, inputs, training, weights, **kwargs):
456
+ """Apply cluster preserved quantization to the input tensor.
457
+
458
+ Args:
459
+ inputs: Input tensor (layer's weights) to be quantized.
460
+ training: Whether the graph is currently training.
461
+ weights: Dictionary of weights (params) the quantizer can use to
462
+ quantize the tensor (layer's weights). This contains the weights
463
+ created in the `build` function.
464
+ **kwargs: Additional variables which may be passed to the quantizer.
465
+ Returns:
466
+ quantized tensor.
467
+ """
468
+ if training:
469
+ if CLUSTER_CENTROIDS in weights:
470
+ if self.preserve_sparsity:
471
+ weights[ORIGINAL_WEIGHTS].assign(
472
+ tf.multiply(weights[ORIGINAL_WEIGHTS],
473
+ weights[SPARSITY_MASK]))
474
+ weights[CLUSTERING_IMPL].cluster_centroids.assign(
475
+ weights[CLUSTERING_IMPL].
476
+ cluster_centroids * weights[CENTROIDS_MASK]
477
+ )
478
+ weights[CLUSTER_CENTROIDS].assign(
479
+ weights[CLUSTERING_IMPL].cluster_centroids
480
+ )
481
+ # Insert clustering variables
482
+ weights[PULLING_INDICES].assign(tf.dtypes.cast(
483
+ weights[CLUSTERING_IMPL].get_pulling_indices(
484
+ weights[ORIGINAL_WEIGHTS]),
485
+ weights[PULLING_INDICES].dtype
486
+ ))
487
+
488
+ output = weights[CLUSTERING_IMPL].get_clustered_weight(
489
+ weights[PULLING_INDICES], weights[ORIGINAL_WEIGHTS])
490
+ inputs.assign(output)
491
+ else:
492
+ if self.preserve_sparsity:
493
+ inputs = tf.multiply(inputs, weights[SPARSITY_MASK])
494
+ output = inputs
495
+ else:
496
+ output = inputs
497
+
498
+ return quant_ops.LastValueQuantize(
499
+ output,
500
+ weights['min_var'],
501
+ weights['max_var'],
502
+ is_training=training,
503
+ num_bits=self.num_bits,
504
+ per_channel=self.per_axis,
505
+ symmetric=self.symmetric,
506
+ narrow_range=self.narrow_range
507
+ )
508
+
509
+
510
+ class ClusterPreserveDefault8BitWeightsQuantizer(
511
+ ClusterPreserveDefaultWeightsQuantizer):
512
+ """ClusterPreserveWeightsQuantizer for default 8bit weights."""
513
+
514
+ def __init__(self, preserve_sparsity):
515
+ super(ClusterPreserveDefault8BitWeightsQuantizer,
516
+ self).__init__(num_bits=8,
517
+ per_axis=False,
518
+ symmetric=True,
519
+ narrow_range=True,
520
+ preserve_sparsity=preserve_sparsity)
521
+ self.preserve_sparsity = preserve_sparsity
522
+
523
+
524
+ class ClusterPreserveDefault8BitConvWeightsQuantizer(
525
+ ClusterPreserveDefaultWeightsQuantizer,
526
+ default_8bit_quantizers.Default8BitConvWeightsQuantizer):
527
+ """ClusterPreserveWeightsQuantizer for default 8bit Conv2D weights."""
528
+
529
+ def __init__(self, preserve_sparsity): # pylint: disable=super-init-not-called
530
+ default_8bit_quantizers.Default8BitConvWeightsQuantizer.__init__(self)
531
+ self.preserve_sparsity = preserve_sparsity
532
+
533
+ def build(self, tensor_shape, name, layer):
534
+ result = ClusterPreserveDefaultWeightsQuantizer._build_clusters(
535
+ self, name, layer)
536
+ result.update(
537
+ default_8bit_quantizers.Default8BitConvWeightsQuantizer.build(
538
+ self, tensor_shape, name, layer))
539
+ return result
cluster_preserve_quantize_registry_test.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the 'License');
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an 'AS IS' BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Tests for ClusterPreserveQuantizeRegistry."""
16
+
17
+ import tensorflow as tf
18
+
19
+ from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry
20
+ from tensorflow_model_optimization.python.core.keras.compat import keras
21
+ from tensorflow_model_optimization.python.core.quantization.keras import quantize_config
22
+ from tensorflow_model_optimization.python.core.quantization.keras.collab_opts.cluster_preserve import cluster_preserve_quantize_registry
23
+ from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
24
+
25
+
26
+ QuantizeConfig = quantize_config.QuantizeConfig
27
+ layers = keras.layers
28
+
29
+
30
+ class ClusterPreserveQuantizeRegistryTest(tf.test.TestCase):
31
+
32
+ def setUp(self):
33
+ super(ClusterPreserveQuantizeRegistryTest, self).setUp()
34
+ # Test CQAT by default
35
+ self.cluster_preserve_quantize_registry = (
36
+ cluster_preserve_quantize_registry.ClusterPreserveQuantizeRegistry(
37
+ False)
38
+ )
39
+ # layers which are supported
40
+ # initial and build a Conv2D layer
41
+ self.layer_conv2d = layers.Conv2D(10, (2, 2))
42
+ self.layer_conv2d.build((2, 2))
43
+ # initial and build a Dense layer
44
+ self.layer_dense = layers.Dense(10)
45
+ self.layer_dense.build((2, 2))
46
+ # initial and build a ReLU layer
47
+ self.layer_relu = layers.ReLU()
48
+ self.layer_relu.build((2, 2))
49
+
50
+ # a layer which is not supported
51
+ # initial and build a Custom layer
52
+ self.layer_custom = self.CustomLayer()
53
+ self.layer_custom.build()
54
+
55
+ class CustomLayer(layers.Layer):
56
+ """A simple custom layer with training weights."""
57
+
58
+ def build(self, input_shape=(2, 2)):
59
+ self.add_weight(shape=input_shape,
60
+ initializer='random_normal',
61
+ trainable=True)
62
+
63
+ class CustomQuantizeConfig(QuantizeConfig):
64
+ """A dummy concrete class for testing unregistered configs."""
65
+
66
+ def get_weights_and_quantizers(self, layer):
67
+ return []
68
+
69
+ def get_activations_and_quantizers(self, layer):
70
+ return []
71
+
72
+ def set_quantize_weights(self, layer, quantize_weights):
73
+ pass
74
+
75
+ def set_quantize_activations(self, layer, quantize_activations):
76
+ pass
77
+
78
+ def get_output_quantizers(self, layer):
79
+ return []
80
+
81
+ def get_config(self):
82
+ return {}
83
+
84
+ def testSupportsKerasLayer(self):
85
+ # test registered layer
86
+ self.assertTrue(
87
+ self.cluster_preserve_quantize_registry.supports(self.layer_dense))
88
+ self.assertTrue(
89
+ self.cluster_preserve_quantize_registry.supports(self.layer_conv2d))
90
+ # test layer without training weights
91
+ self.assertTrue(
92
+ self.cluster_preserve_quantize_registry.supports(self.layer_relu))
93
+
94
+ def testDoesNotSupportCustomLayer(self):
95
+ self.assertFalse(
96
+ self.cluster_preserve_quantize_registry.supports(self.layer_custom))
97
+
98
+ def testApplyClusterPreserveWithQuantizeConfig(self):
99
+ (self.cluster_preserve_quantize_registry
100
+ .apply_cluster_preserve_quantize_config(
101
+ self.layer_conv2d,
102
+ default_8bit_quantize_registry.Default8BitConvQuantizeConfig(
103
+ ['kernel'], ['activation'], False)))
104
+
105
+ def testRaisesErrorUnsupportedQuantizeConfigWithLayer(self):
106
+ with self.assertRaises(
107
+ ValueError, msg='Unregistered QuantizeConfigs should raise error.'):
108
+ (self.cluster_preserve_quantize_registry.
109
+ apply_cluster_preserve_quantize_config(
110
+ self.layer_conv2d, self.CustomQuantizeConfig))
111
+
112
+ with self.assertRaises(ValueError,
113
+ msg='Unregistered layers should raise error.'):
114
+ (self.cluster_preserve_quantize_registry.
115
+ apply_cluster_preserve_quantize_config(
116
+ self.layer_custom, self.CustomQuantizeConfig))
117
+
118
+
119
+ class ClusterPreserveDefault8bitQuantizeRegistryTest(tf.test.TestCase):
120
+
121
+ def setUp(self):
122
+ super(ClusterPreserveDefault8bitQuantizeRegistryTest, self).setUp()
123
+ self.default_8bit_quantize_registry = (
124
+ default_8bit_quantize_registry.Default8BitQuantizeRegistry())
125
+ self.cluster_registry = clustering_registry.ClusteringRegistry()
126
+ # Test CQAT by default
127
+ self.cluster_preserve_quantize_registry = (
128
+ cluster_preserve_quantize_registry.ClusterPreserveQuantizeRegistry(
129
+ False))
130
+
131
+ def testSupportsClusterDefault8bitQuantizeKerasLayers(self):
132
+ # ClusterPreserveQuantize supported layer, must be suppoted
133
+ # by both Cluster and Quantize
134
+ cqat_layers_config_map = (
135
+ self.cluster_preserve_quantize_registry._LAYERS_CONFIG_MAP)
136
+ for cqat_support_layer in cqat_layers_config_map:
137
+ if cqat_layers_config_map[cqat_support_layer].weight_attrs and (
138
+ cqat_layers_config_map[cqat_support_layer].quantize_config_attrs):
139
+ self.assertIn(
140
+ cqat_support_layer, self.cluster_registry._LAYERS_WEIGHTS_MAP,
141
+ msg='Clusteirng doesn\'t support {}'.format(cqat_support_layer))
142
+ self.assertIn(
143
+ cqat_support_layer,
144
+ self.default_8bit_quantize_registry._layer_quantize_map,
145
+ msg='Default 8bit QAT doesn\'t support {}'.format(
146
+ cqat_support_layer))
147
+
148
+
149
+ if __name__ == '__main__':
150
+ tf.test.main()
collaborative_optimization.png ADDED
collaborative_optimization_dist.png ADDED
cripto.jpg ADDED
deep_crypto.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from predictors.btc_ltsm import BtcLtsm
3
+
4
+ if __name__ == "__main__":
5
+ parser = argparse.ArgumentParser(description='BTC Price Prediction')
6
+ parser.add_argument('--update', action='store_true', help='Update the dataset')
7
+ parser.add_argument('--train', action='store_true', help='Train the model')
8
+ parser.add_argument('--test', action='store_true', help='Test the model')
9
+ args = parser.parse_args()
10
+
11
+ btc_ltsm = BtcLtsm()
12
+ if args.update:
13
+ btc_ltsm.update_dataset()
14
+ if args.train:
15
+ btc_ltsm.train()
16
+ if args.test:
17
+ btc_ltsm.load()
18
+ btc_ltsm.test_model()
default_n_bit_transforms.py ADDED
@@ -0,0 +1,825 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Default 8-bit transforms."""
16
+
17
+ import collections
18
+ import inspect
19
+
20
+ import numpy as np
21
+ import tensorflow as tf
22
+
23
+ from tensorflow_model_optimization.python.core.keras.compat import keras
24
+ from tensorflow_model_optimization.python.core.keras.compat import unique_object_name
25
+ from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
26
+ from tensorflow_model_optimization.python.core.quantization.keras import quantize_layer
27
+ from tensorflow_model_optimization.python.core.quantization.keras import quantizers
28
+ from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils
29
+ from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_quantize_configs as configs
30
+ from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_quantize_registry
31
+ from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms
32
+
33
+
34
+ LayerNode = transforms.LayerNode
35
+ LayerPattern = transforms.LayerPattern
36
+
37
+
38
+ def _get_conv_bn_layers(bn_layer_node):
39
+ bn_layer = bn_layer_node.layer
40
+ conv_layer = bn_layer_node.input_layers[0].layer
41
+ return conv_layer, bn_layer
42
+
43
+
44
+ def _get_weights(bn_layer_node):
45
+ """Returns weight values for fused layer, including copying original values in unfused version."""
46
+
47
+ return collections.OrderedDict(
48
+ list(bn_layer_node.input_layers[0].weights.items())
49
+ + list(bn_layer_node.weights.items()))
50
+
51
+
52
+ def _get_params(conv_layer, bn_layer, relu_layer=None):
53
+ """Retrieve conv_bn params within wrapped layers."""
54
+ if 'use_bias' in conv_layer['config']:
55
+ if conv_layer['config']['use_bias']:
56
+ raise ValueError(
57
+ 'use_bias should not be set to True in a Conv layer when followed '
58
+ 'by BatchNormalization. The bias in the Conv would be redundant '
59
+ 'with the one in the BatchNormalization.')
60
+
61
+ del conv_layer['config']['use_bias']
62
+
63
+ if 'name' in bn_layer['config']:
64
+ del bn_layer['config']['name']
65
+
66
+ # TODO(pulkitb): remove key conflicts
67
+ params = dict(
68
+ list(conv_layer['config'].items()) + list(bn_layer['config'].items()))
69
+
70
+ if relu_layer is not None:
71
+ params['post_activation'] = quantize_utils.deserialize_layer(
72
+ relu_layer, use_legacy_format=True
73
+ )
74
+
75
+ return params
76
+
77
+
78
+ def _get_layer_node(fused_layer, weights):
79
+ layer_config = quantize_utils.serialize_layer(
80
+ fused_layer, use_legacy_format=True
81
+ )
82
+ layer_config['name'] = layer_config['config']['name']
83
+ # This config tracks which layers get quantized, and whether they have a
84
+ # custom QuantizeConfig.
85
+ layer_metadata = {'quantize_config': None}
86
+
87
+ return LayerNode(layer_config, weights, metadata=layer_metadata)
88
+
89
+
90
+ def _get_quantize_config(layer_node):
91
+ return layer_node.metadata.get('quantize_config')
92
+
93
+
94
+ def _has_custom_quantize_config(*layer_nodes):
95
+ for layer_node in layer_nodes:
96
+ if _get_quantize_config(layer_node) is not None:
97
+ return True
98
+ return False
99
+
100
+
101
+ def _normalize_tuple(value):
102
+ if isinstance(value, int):
103
+ return (value,)
104
+ else:
105
+ return tuple(value)
106
+
107
+
108
+ class Conv2DBatchNormQuantize(transforms.Transform):
109
+ """Ensure FQ does not get placed between Conv and BatchNorm."""
110
+
111
+ def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
112
+ self._num_bits_weight = num_bits_weight
113
+ self._num_bits_activation = num_bits_activation
114
+
115
+ def pattern(self):
116
+ return LayerPattern(
117
+ 'BatchNormalization|SyncBatchNormalization',
118
+ inputs=[LayerPattern(
119
+ 'Conv2D|DepthwiseConv2D', config={'activation': 'linear'})])
120
+
121
+ def _replace(self, bn_layer_node, conv_layer_node):
122
+ if _has_custom_quantize_config(bn_layer_node, conv_layer_node):
123
+ return bn_layer_node
124
+
125
+ conv_layer_node.layer['config']['activation'] = (
126
+ quantize_utils.serialize_activation(
127
+ quantize_aware_activation.NoOpActivation(), use_legacy_format=True
128
+ )
129
+ )
130
+ bn_layer_node.metadata['quantize_config'] = (
131
+ configs.DefaultNBitOutputQuantizeConfig(
132
+ num_bits_weight=self._num_bits_weight,
133
+ num_bits_activation=self._num_bits_activation))
134
+
135
+ return bn_layer_node
136
+
137
+ def replacement(self, match_layer):
138
+ bn_layer_node = match_layer
139
+ conv_layer_node = match_layer.input_layers[0]
140
+
141
+ return self._replace(bn_layer_node, conv_layer_node)
142
+
143
+ def custom_objects(self):
144
+ return {
145
+ 'NoOpQuantizeConfig':
146
+ configs.NoOpQuantizeConfig,
147
+ 'NoOpActivation':
148
+ quantize_aware_activation.NoOpActivation
149
+ }
150
+
151
+
152
+ class Conv2DReshapeBatchNormQuantize(Conv2DBatchNormQuantize):
153
+ """Ensure FQ does not get placed between Conv, Reshape and BatchNorm."""
154
+
155
+ def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
156
+ super(Conv2DReshapeBatchNormQuantize, self).__init__(
157
+ num_bits_weight=num_bits_weight,
158
+ num_bits_activation=num_bits_activation)
159
+ self._num_bits_weight = num_bits_weight
160
+ self._num_bits_activation = num_bits_activation
161
+
162
+ def pattern(self):
163
+ return LayerPattern(
164
+ 'BatchNormalization|SyncBatchNormalization',
165
+ inputs=[LayerPattern(
166
+ 'Lambda', config={'name': 'sepconv1d_squeeze.*'},
167
+ inputs=[LayerPattern(
168
+ 'Conv2D|DepthwiseConv2D',
169
+ config={'activation': 'linear'})])])
170
+
171
+ def replacement(self, match_layer):
172
+ bn_layer_node = match_layer
173
+ reshape_layer_node = bn_layer_node.input_layers[0]
174
+ conv_layer_node = reshape_layer_node.input_layers[0]
175
+
176
+ return self._replace(bn_layer_node, conv_layer_node)
177
+
178
+
179
+ class Conv2DBatchNormReLUQuantize(Conv2DBatchNormQuantize):
180
+ """Ensure FQ does not get placed between Conv, BatchNorm and ReLU."""
181
+
182
+ def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
183
+ super(Conv2DBatchNormReLUQuantize, self).__init__(
184
+ num_bits_weight=num_bits_weight,
185
+ num_bits_activation=num_bits_activation)
186
+ self._num_bits_weight = num_bits_weight
187
+ self._num_bits_activation = num_bits_activation
188
+
189
+ def pattern(self):
190
+ return LayerPattern(
191
+ # TODO(pulkitb): Enhance match to only occur for relu, relu1 and relu6
192
+ 'ReLU',
193
+ inputs=[super(Conv2DBatchNormReLUQuantize, self).pattern()])
194
+
195
+ def _replace(self, relu_layer_node, bn_layer_node, conv_layer_node):
196
+ if _has_custom_quantize_config(
197
+ relu_layer_node, bn_layer_node, conv_layer_node):
198
+ return relu_layer_node
199
+
200
+ conv_layer_node.layer['config']['activation'] = (
201
+ quantize_utils.serialize_activation(
202
+ quantize_aware_activation.NoOpActivation(), use_legacy_format=True
203
+ )
204
+ )
205
+ bn_layer_node.metadata['quantize_config'] = (
206
+ configs.NoOpQuantizeConfig())
207
+
208
+ return relu_layer_node
209
+
210
+ def replacement(self, match_layer):
211
+ relu_layer_node = match_layer
212
+ bn_layer_node = relu_layer_node.input_layers[0]
213
+ conv_layer_node = bn_layer_node.input_layers[0]
214
+
215
+ return self._replace(relu_layer_node, bn_layer_node, conv_layer_node)
216
+
217
+
218
+ class Conv2DBatchNormActivationQuantize(Conv2DBatchNormReLUQuantize):
219
+ """Ensure FQ does not get placed between Conv, BatchNorm and ReLU."""
220
+
221
+ def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
222
+ super(Conv2DBatchNormActivationQuantize, self).__init__(
223
+ num_bits_weight=num_bits_weight,
224
+ num_bits_activation=num_bits_activation)
225
+ self._num_bits_weight = num_bits_weight
226
+ self._num_bits_activation = num_bits_activation
227
+
228
+ def pattern(self):
229
+ return LayerPattern(
230
+ 'Activation',
231
+ config={'activation': 'relu'},
232
+ inputs=[Conv2DBatchNormQuantize.pattern(self)])
233
+
234
+
235
+ class Conv2DReshapeBatchNormReLUQuantize(Conv2DBatchNormReLUQuantize):
236
+ """Ensure FQ does not get placed between Conv, BatchNorm and ReLU."""
237
+
238
+ def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
239
+ super(Conv2DReshapeBatchNormReLUQuantize, self).__init__(
240
+ num_bits_weight=num_bits_weight,
241
+ num_bits_activation=num_bits_activation)
242
+ self._num_bits_weight = num_bits_weight
243
+ self._num_bits_activation = num_bits_activation
244
+
245
+ def pattern(self):
246
+ return LayerPattern(
247
+ 'ReLU',
248
+ inputs=[Conv2DReshapeBatchNormQuantize.pattern(self)])
249
+
250
+ def replacement(self, match_layer):
251
+ relu_layer_node = match_layer
252
+ bn_layer_node = relu_layer_node.input_layers[0]
253
+ squeeze_layer_node = bn_layer_node.input_layers[0]
254
+ conv_layer_node = squeeze_layer_node.input_layers[0]
255
+
256
+ return self._replace(relu_layer_node, bn_layer_node, conv_layer_node)
257
+
258
+
259
+ class Conv2DReshapeBatchNormActivationQuantize(
260
+ Conv2DReshapeBatchNormReLUQuantize):
261
+ """Ensure FQ does not get placed between Conv, BatchNorm and ReLU."""
262
+
263
+ def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
264
+ super(Conv2DReshapeBatchNormActivationQuantize, self).__init__(
265
+ num_bits_weight=num_bits_weight,
266
+ num_bits_activation=num_bits_activation)
267
+ self._num_bits_weight = num_bits_weight
268
+ self._num_bits_activation = num_bits_activation
269
+
270
+ def pattern(self):
271
+ return LayerPattern(
272
+ 'Activation',
273
+ config={'activation': 'relu'},
274
+ inputs=[Conv2DReshapeBatchNormQuantize.pattern(self)])
275
+
276
+
277
+ class DenseBatchNormQuantize(transforms.Transform):
278
+ """Transform to be applied to "Dense"+ "BatchNorm" Graph.
279
+
280
+ This transform disables Quantization between Dense and BatchNorm
281
+ to ensure FQ does not get placed between them.
282
+ """
283
+
284
+ def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
285
+ self._num_bits_weight = num_bits_weight
286
+ self._num_bits_activation = num_bits_activation
287
+
288
+ def pattern(self):
289
+ return LayerPattern(
290
+ 'BatchNormalization|SyncBatchNormalization',
291
+ inputs=[LayerPattern('Dense', config={'activation': 'linear'})])
292
+
293
+ def _replace(self, bn_layer_node, dense_layer_node):
294
+ if _has_custom_quantize_config(bn_layer_node, dense_layer_node):
295
+ return bn_layer_node
296
+
297
+ dense_layer_node.layer['config']['activation'] = (
298
+ quantize_utils.serialize_activation(
299
+ quantize_aware_activation.NoOpActivation(), use_legacy_format=True
300
+ )
301
+ )
302
+ bn_layer_node.metadata['quantize_config'] = (
303
+ configs.DefaultNBitOutputQuantizeConfig(
304
+ num_bits_weight=self._num_bits_weight,
305
+ num_bits_activation=self._num_bits_activation))
306
+ return bn_layer_node
307
+
308
+ def replacement(self, match_layer):
309
+ bn_layer_node = match_layer
310
+ dense_layer_node = match_layer.input_layers[0]
311
+
312
+ return self._replace(bn_layer_node, dense_layer_node)
313
+
314
+ def custom_objects(self):
315
+ return {
316
+ 'DefaultNBitOutputQuantizeConfig':
317
+ configs.DefaultNBitOutputQuantizeConfig,
318
+ 'NoOpQuantizeConfig':
319
+ configs.NoOpQuantizeConfig,
320
+ 'NoOpActivation': quantize_aware_activation.NoOpActivation
321
+ }
322
+
323
+
324
+ class DenseBatchNormReLUQuantize(DenseBatchNormQuantize):
325
+ """Transform to be applied to "Dense"+ "BatchNorm" + "ReLU" Graph.
326
+
327
+ This transform disables Quantization between Dense, BatchNorm and ReLU
328
+ to ensure FQ does not get placed between them.
329
+ """
330
+
331
+ def pattern(self):
332
+ return LayerPattern(
333
+ 'ReLU', inputs=[super(DenseBatchNormReLUQuantize, self).pattern()])
334
+
335
+ def _replace(self, relu_layer_node, bn_layer_node, dense_layer_node):
336
+ if _has_custom_quantize_config(relu_layer_node, bn_layer_node,
337
+ dense_layer_node):
338
+ return relu_layer_node
339
+
340
+ dense_layer_node.layer['config']['activation'] = (
341
+ quantize_utils.serialize_activation(
342
+ quantize_aware_activation.NoOpActivation(), use_legacy_format=True
343
+ )
344
+ )
345
+ bn_layer_node.metadata['quantize_config'] = (
346
+ configs.NoOpQuantizeConfig())
347
+
348
+ return relu_layer_node
349
+
350
+ def replacement(self, match_layer):
351
+ relu_layer_node = match_layer
352
+ bn_layer_node = relu_layer_node.input_layers[0]
353
+ dense_layer_node = bn_layer_node.input_layers[0]
354
+
355
+ return self._replace(relu_layer_node, bn_layer_node, dense_layer_node)
356
+
357
+
358
+ class DenseBatchNormActivationQuantize(DenseBatchNormReLUQuantize):
359
+ """Transform to be applied to "Dense"+ "BatchNorm" + "ReLU" Graph.
360
+
361
+ This transform disables Quantization between Dense, BatchNorm and ReLU
362
+ to ensure FQ does not get placed between them.
363
+ """
364
+
365
+ def pattern(self):
366
+ return LayerPattern(
367
+ 'Activation',
368
+ config={'activation': 'relu'},
369
+ inputs=[DenseBatchNormQuantize.pattern(self)])
370
+
371
+
372
+ class SeparableConv1DQuantize(transforms.Transform):
373
+ """Add QAT support for Keras SeparableConv1D layer.
374
+
375
+ Transforms SeparableConv1D into a SeparableConv2D invocation. The Keras
376
+ SeparableConv1D layer internally uses the same code as a SeparbaleConv2D
377
+ layer. It simple expands and squeezes the tensor dimensions before and after
378
+ the convolutions. Applying this transform ensures the QAT handling for
379
+ SeparableConv2D kicks in and handles the FQ placement properly.
380
+
381
+ Maps:
382
+ Input -> SeparableConv1D -> Output
383
+ to
384
+ Input -> Lambda(ExpandDims) -> SeparableConv2D -> Lambda(Squeeze) -> Output
385
+
386
+ Unlike SeparableConv2DQuantize, this does not break the layer into
387
+ DepthwiseConv and Conv separately, since no DepthwiseConv1D exists.
388
+ """
389
+
390
+ def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
391
+ self._num_bits_weight = num_bits_weight
392
+ self._num_bits_activation = num_bits_activation
393
+
394
+ def pattern(self):
395
+ return LayerPattern('SeparableConv1D')
396
+
397
+ def _get_name(self, prefix):
398
+ # TODO(pulkitb): Move away from `unique_object_name` since it isn't
399
+ # exposed as externally usable.
400
+ return unique_object_name(prefix)
401
+
402
+ def replacement(self, match_layer):
403
+ if _has_custom_quantize_config(match_layer):
404
+ return match_layer
405
+
406
+ sepconv1d_layer = match_layer.layer
407
+ sepconv1d_config = sepconv1d_layer['config']
408
+ sepconv1d_weights = list(match_layer.weights.values())
409
+
410
+ padding = sepconv1d_config['padding']
411
+ # SepConv2D does not accept causal padding, and SepConv1D has some special
412
+ # handling for it.
413
+ # TODO(pulkitb): Add support for causal padding.
414
+ if padding == 'causal':
415
+ raise ValueError('SeparableConv1D with causal padding is not supported.')
416
+
417
+ # TODO(pulkitb): Handle other base_layer args such as dtype, input_dim etc.
418
+
419
+ sepconv2d_layer = keras.layers.SeparableConv2D(
420
+ filters=sepconv1d_config['filters'],
421
+ kernel_size=(1,) + _normalize_tuple(sepconv1d_config['kernel_size']),
422
+ strides=_normalize_tuple(sepconv1d_config['strides']) * 2,
423
+ padding=padding,
424
+ data_format=sepconv1d_config['data_format'],
425
+ dilation_rate=(1,)
426
+ + _normalize_tuple(sepconv1d_config['dilation_rate']),
427
+ depth_multiplier=sepconv1d_config['depth_multiplier'],
428
+ activation=sepconv1d_config['activation'],
429
+ use_bias=sepconv1d_config['use_bias'],
430
+ depthwise_initializer=sepconv1d_config['depthwise_initializer'],
431
+ pointwise_initializer=sepconv1d_config['pointwise_initializer'],
432
+ bias_initializer=sepconv1d_config['bias_initializer'],
433
+ depthwise_regularizer=sepconv1d_config['depthwise_regularizer'],
434
+ pointwise_regularizer=sepconv1d_config['pointwise_regularizer'],
435
+ bias_regularizer=sepconv1d_config['bias_regularizer'],
436
+ activity_regularizer=sepconv1d_config['activity_regularizer'],
437
+ depthwise_constraint=sepconv1d_config['depthwise_constraint'],
438
+ pointwise_constraint=sepconv1d_config['pointwise_constraint'],
439
+ bias_constraint=sepconv1d_config['bias_constraint'],
440
+ # TODO(pulkitb): Rethink what to do for name. Using the same name leads
441
+ # to confusion, since it's typically separable_conv1d
442
+ name=sepconv1d_config['name'] + '_QAT_SepConv2D',
443
+ trainable=sepconv1d_config['trainable'],
444
+ )
445
+
446
+ sepconv2d_weights = collections.OrderedDict()
447
+ sepconv2d_weights['depthwise_kernel:0'] = np.expand_dims(
448
+ sepconv1d_weights[0], 0)
449
+ sepconv2d_weights['pointwise_kernel:0'] = np.expand_dims(
450
+ sepconv1d_weights[1], 0)
451
+ if sepconv1d_config['use_bias']:
452
+ sepconv2d_weights['bias:0'] = sepconv1d_weights[2]
453
+
454
+ if sepconv1d_config['data_format'] == 'channels_last':
455
+ spatial_dim = 1
456
+ else:
457
+ spatial_dim = 2
458
+
459
+ sepconv2d_layer_config = quantize_utils.serialize_layer(
460
+ sepconv2d_layer, use_legacy_format=True
461
+ )
462
+ sepconv2d_layer_config['name'] = sepconv2d_layer.name
463
+
464
+ # Needed to ensure these new layers are considered for quantization.
465
+ sepconv2d_metadata = {'quantize_config': None}
466
+
467
+ # TODO(pulkitb): Consider moving from Lambda to custom ExpandDims/Squeeze.
468
+
469
+ # Layer before SeparableConv2D which expands input tensors to match 2D.
470
+ expand_layer = keras.layers.Lambda(
471
+ lambda x: tf.expand_dims(x, spatial_dim),
472
+ name=self._get_name('sepconv1d_expand'),
473
+ )
474
+ expand_layer_config = quantize_utils.serialize_layer(
475
+ expand_layer, use_legacy_format=True
476
+ )
477
+ expand_layer_config['name'] = expand_layer.name
478
+ expand_layer_metadata = {
479
+ 'quantize_config':
480
+ configs.NoOpQuantizeConfig()}
481
+
482
+ squeeze_layer = keras.layers.Lambda(
483
+ lambda x: tf.squeeze(x, [spatial_dim]),
484
+ name=self._get_name('sepconv1d_squeeze'),
485
+ )
486
+ squeeze_layer_config = quantize_utils.serialize_layer(
487
+ squeeze_layer, use_legacy_format=True
488
+ )
489
+ squeeze_layer_config['name'] = squeeze_layer.name
490
+ squeeze_layer_metadata = {
491
+ 'quantize_config':
492
+ configs.NoOpQuantizeConfig()}
493
+
494
+ return LayerNode(
495
+ squeeze_layer_config,
496
+ metadata=squeeze_layer_metadata,
497
+ input_layers=[LayerNode(
498
+ sepconv2d_layer_config,
499
+ weights=sepconv2d_weights,
500
+ metadata=sepconv2d_metadata,
501
+ input_layers=[LayerNode(
502
+ expand_layer_config, metadata=expand_layer_metadata)]
503
+ )])
504
+
505
+
506
+ class SeparableConvQuantize(transforms.Transform):
507
+ """Break SeparableConv into a DepthwiseConv and Conv layer.
508
+
509
+ SeparableConv is a composition of a DepthwiseConv and a Conv layer. For the
510
+ purpose of quantization, a FQ operation needs to be placed between the output
511
+ of DepthwiseConv and the following Conv.
512
+
513
+ This is needed since there is a dynamic tensor in between the two layers, and
514
+ it's range information needs to be captured by the FakeQuant op to ensure
515
+ full int8 quantization of the layers is possible.
516
+
517
+ Splitting the layer into 2 ensures that each individual layer is handled
518
+ correctly with respect to quantization.
519
+ """
520
+
521
+ def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
522
+ self._num_bits_weight = num_bits_weight
523
+ self._num_bits_activation = num_bits_activation
524
+
525
+ def pattern(self):
526
+ return LayerPattern('SeparableConv2D')
527
+
528
+ def replacement(self, match_layer):
529
+ if _has_custom_quantize_config(match_layer):
530
+ return match_layer
531
+
532
+ sepconv_layer = match_layer.layer
533
+ sepconv_weights = list(match_layer.weights.values())
534
+
535
+ # TODO(pulkitb): SeparableConv has kwargs other than constructor args which
536
+ # need to be handled.
537
+ # Applicable to both layers: trainable, dtype, name
538
+ # Applicable to dconv: input_dim, input_shape, batch_input_shape, batch_size
539
+ # Needs special handling: weights
540
+ # Unknown: dynamic, autocast
541
+
542
+ dconv_layer = keras.layers.DepthwiseConv2D(
543
+ kernel_size=sepconv_layer['config']['kernel_size'],
544
+ strides=sepconv_layer['config']['strides'],
545
+ padding=sepconv_layer['config']['padding'],
546
+ depth_multiplier=sepconv_layer['config']['depth_multiplier'],
547
+ data_format=sepconv_layer['config']['data_format'],
548
+ dilation_rate=sepconv_layer['config']['dilation_rate'],
549
+ activation=None,
550
+ use_bias=False,
551
+ depthwise_initializer=sepconv_layer['config']['depthwise_initializer'],
552
+ depthwise_regularizer=sepconv_layer['config']['depthwise_regularizer'],
553
+ depthwise_constraint=sepconv_layer['config']['depthwise_constraint'],
554
+ trainable=sepconv_layer['config']['trainable'],
555
+ )
556
+ dconv_weights = collections.OrderedDict()
557
+ dconv_weights['depthwise_kernel:0'] = sepconv_weights[0]
558
+ dconv_layer_config = quantize_utils.serialize_layer(
559
+ dconv_layer, use_legacy_format=True
560
+ )
561
+ dconv_layer_config['name'] = dconv_layer.name
562
+ # Needed to ensure these new layers are considered for quantization.
563
+ dconv_metadata = {'quantize_config': None}
564
+
565
+ conv_layer = keras.layers.Conv2D(
566
+ filters=sepconv_layer['config']['filters'],
567
+ kernel_size=(1, 1), # (1,) * rank
568
+ strides=(1, 1),
569
+ padding='valid',
570
+ data_format=sepconv_layer['config']['data_format'],
571
+ dilation_rate=sepconv_layer['config']['dilation_rate'],
572
+ groups=1,
573
+ activation=sepconv_layer['config']['activation'],
574
+ use_bias=sepconv_layer['config']['use_bias'],
575
+ kernel_initializer=sepconv_layer['config']['pointwise_initializer'],
576
+ bias_initializer=sepconv_layer['config']['bias_initializer'],
577
+ kernel_regularizer=sepconv_layer['config']['pointwise_regularizer'],
578
+ bias_regularizer=sepconv_layer['config']['bias_regularizer'],
579
+ activity_regularizer=sepconv_layer['config']['activity_regularizer'],
580
+ kernel_constraint=sepconv_layer['config']['pointwise_constraint'],
581
+ bias_constraint=sepconv_layer['config']['bias_constraint'],
582
+ trainable=sepconv_layer['config']['trainable'],
583
+ )
584
+ conv_weights = collections.OrderedDict()
585
+ conv_weights['kernel:0'] = sepconv_weights[1]
586
+ if sepconv_layer['config']['use_bias']:
587
+ conv_weights['bias:0'] = sepconv_weights[2]
588
+ conv_layer_config = quantize_utils.serialize_layer(
589
+ conv_layer, use_legacy_format=True
590
+ )
591
+ conv_layer_config['name'] = conv_layer.name
592
+ # Needed to ensure these new layers are considered for quantization.
593
+ conv_metadata = {'quantize_config': None}
594
+
595
+ dconv_layer_node = LayerNode(
596
+ dconv_layer_config, weights=dconv_weights, metadata=dconv_metadata)
597
+ return LayerNode(
598
+ conv_layer_config,
599
+ weights=conv_weights,
600
+ input_layers=[dconv_layer_node],
601
+ metadata=conv_metadata)
602
+
603
+
604
+ class LayerReLUQuantize(transforms.Transform):
605
+ """Ensure FQ does not get placed between Add and ReLU."""
606
+
607
+ def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
608
+ self._num_bits_weight = num_bits_weight
609
+ self._num_bits_activation = num_bits_activation
610
+
611
+ def pattern(self):
612
+ return LayerPattern(
613
+ 'ReLU', inputs=[LayerPattern('Add|Conv2D|DepthwiseConv2D|Dense')])
614
+
615
+ def replacement(self, match_layer):
616
+ relu_layer_node = match_layer
617
+ add_layer_node = relu_layer_node.input_layers[0]
618
+
619
+ add_layer_node.metadata['quantize_config'] = (
620
+ configs.NoOpQuantizeConfig())
621
+
622
+ return match_layer
623
+
624
+ def custom_objects(self):
625
+ return {
626
+ 'NoOpQuantizeConfig':
627
+ configs.NoOpQuantizeConfig,
628
+ }
629
+
630
+
631
+ class LayerReluActivationQuantize(LayerReLUQuantize):
632
+ """Ensure FQ does not get placed between Add and ReLU."""
633
+
634
+ def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
635
+ super(LayerReluActivationQuantize, self).__init__(
636
+ num_bits_weight=num_bits_weight,
637
+ num_bits_activation=num_bits_activation)
638
+ self._num_bits_weight = num_bits_weight
639
+ self._num_bits_activation = num_bits_activation
640
+
641
+ def pattern(self):
642
+ return LayerPattern(
643
+ 'Activation',
644
+ config={'activation': 'relu'},
645
+ inputs=[LayerPattern('Add|Conv2D|DepthwiseConv2D|Dense')])
646
+
647
+
648
+ class InputLayerQuantize(transforms.Transform):
649
+ """Quantizes InputLayer, by adding QuantizeLayer after it.
650
+
651
+ InputLayer => InputLayer -> QuantizeLayer
652
+ """
653
+
654
+ def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
655
+ self._num_bits_weight = num_bits_weight
656
+ self._num_bits_activation = num_bits_activation
657
+
658
+ def pattern(self):
659
+ return LayerPattern('InputLayer')
660
+
661
+ def replacement(self, match_layer):
662
+ quant_layer = quantize_layer.QuantizeLayer(
663
+ quantizers.AllValuesQuantizer(
664
+ num_bits=self._num_bits_activation, per_axis=False,
665
+ symmetric=False, narrow_range=False)) # activation/output
666
+ layer_config = quantize_utils.serialize_layer(
667
+ quant_layer, use_legacy_format=True
668
+ )
669
+ layer_config['name'] = quant_layer.name
670
+
671
+ quant_layer_node = LayerNode(
672
+ layer_config,
673
+ input_layers=[match_layer])
674
+
675
+ return quant_layer_node
676
+
677
+ def custom_objects(self):
678
+ return {
679
+ 'QuantizeLayer': quantize_layer.QuantizeLayer,
680
+ 'MovingAverageQuantizer': quantizers.MovingAverageQuantizer,
681
+ 'AllValuesQuantizer': quantizers.AllValuesQuantizer
682
+ }
683
+
684
+
685
+ class ConcatTransform(transforms.Transform):
686
+ """Transform for Concatenate. Quantize only after concatenation."""
687
+
688
+ # pylint:disable=protected-access
689
+
690
+ def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
691
+ self._num_bits_weight = num_bits_weight
692
+ self._num_bits_activation = num_bits_activation
693
+
694
+ def pattern(self):
695
+ # TODO(pulkitb): Write a clean way to handle length patterns.
696
+ return LayerPattern(
697
+ 'Concatenate', inputs=[LayerPattern('.*'), LayerPattern('.*')])
698
+
699
+ def _get_layer_type(self, layer_class_name):
700
+ keras_layers = inspect.getmembers(keras.layers, inspect.isclass)
701
+ for layer_name, layer_type in keras_layers:
702
+ if layer_name == layer_class_name:
703
+ return layer_type
704
+ return None
705
+
706
+ def _disable_output_quantize(self, quantize_config):
707
+ # TODO(pulkitb): Disabling quantize_config may also require handling
708
+ # activation quantizers. Handle that properly.
709
+ quantize_config.get_output_quantizers = lambda layer: []
710
+
711
+ def replacement(self, match_layer):
712
+ concat_layer_node = match_layer
713
+ feeding_layer_nodes = match_layer.input_layers
714
+
715
+ default_registry = (
716
+ default_n_bit_quantize_registry.DefaultNBitQuantizeRegistry(
717
+ num_bits_weight=self._num_bits_weight,
718
+ num_bits_activation=self._num_bits_activation))
719
+
720
+ feed_quantize_configs = []
721
+ for feed_layer_node in feeding_layer_nodes:
722
+ quantize_config = feed_layer_node.metadata.get('quantize_config')
723
+ if not quantize_config:
724
+ layer_class = self._get_layer_type(feed_layer_node.layer['class_name'])
725
+ if layer_class is None:
726
+ # Concat has an input layer we don't recognize. Return.
727
+ return match_layer
728
+
729
+ if layer_class == keras.layers.Concatenate:
730
+ # Input layer to Concat is also Concat. Don't quantize it.
731
+ feed_layer_node.metadata['quantize_config'] = (
732
+ configs.NoOpQuantizeConfig())
733
+ continue
734
+
735
+ if not default_registry._is_supported_layer(layer_class):
736
+ # Feeding layer is not supported by registry
737
+ return match_layer
738
+
739
+ quantize_config = default_registry._get_quantize_config(layer_class)
740
+ feed_layer_node.metadata['quantize_config'] = quantize_config
741
+
742
+ feed_quantize_configs.append(quantize_config)
743
+
744
+ # TODO(pulkitb): this currently only disables output quantize config, but
745
+ # cannot properly handle if the FQ was added to the activation. Hand this
746
+ # properly.
747
+ for quantize_config in feed_quantize_configs:
748
+ self._disable_output_quantize(quantize_config)
749
+
750
+ if not concat_layer_node.metadata.get('quantize_config'):
751
+ concat_layer_node.metadata['quantize_config'] = (
752
+ configs.DefaultNBitOutputQuantizeConfig(
753
+ num_bits_weight=self._num_bits_weight,
754
+ num_bits_activation=self._num_bits_activation))
755
+
756
+ return concat_layer_node
757
+
758
+ # pylint:enable=protected-access
759
+
760
+
761
+ class ConcatTransform3Inputs(ConcatTransform):
762
+ """Transform for 3 inputs Concatenate."""
763
+
764
+ def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
765
+ super(ConcatTransform3Inputs, self).__init__(
766
+ num_bits_weight=num_bits_weight,
767
+ num_bits_activation=num_bits_activation)
768
+ self._num_bits_weight = num_bits_weight
769
+ self._num_bits_activation = num_bits_activation
770
+
771
+ def pattern(self):
772
+ return LayerPattern(
773
+ 'Concatenate',
774
+ inputs=[LayerPattern('.*'), LayerPattern('.*'), LayerPattern('.*')])
775
+
776
+
777
+ class ConcatTransform4Inputs(ConcatTransform):
778
+ """Transform for 4 inputs Concatenate."""
779
+
780
+ def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
781
+ super(ConcatTransform4Inputs, self).__init__(
782
+ num_bits_weight=num_bits_weight,
783
+ num_bits_activation=num_bits_activation)
784
+ self._num_bits_weight = num_bits_weight
785
+ self._num_bits_activation = num_bits_activation
786
+
787
+ def pattern(self):
788
+ return LayerPattern(
789
+ 'Concatenate',
790
+ inputs=[LayerPattern('.*'), LayerPattern('.*'), LayerPattern('.*'),
791
+ LayerPattern('.*')])
792
+
793
+
794
+ class ConcatTransform5Inputs(ConcatTransform):
795
+ """Transform for 5 inputs Concatenate."""
796
+
797
+ def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
798
+ super(ConcatTransform5Inputs, self).__init__(
799
+ num_bits_weight=num_bits_weight,
800
+ num_bits_activation=num_bits_activation)
801
+ self._num_bits_weight = num_bits_weight
802
+ self._num_bits_activation = num_bits_activation
803
+
804
+ def pattern(self):
805
+ return LayerPattern(
806
+ 'Concatenate',
807
+ inputs=[LayerPattern('.*'), LayerPattern('.*'), LayerPattern('.*'),
808
+ LayerPattern('.*'), LayerPattern('.*')])
809
+
810
+
811
+ class ConcatTransform6Inputs(ConcatTransform):
812
+ """Transform for 6 inputs Concatenate."""
813
+
814
+ def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
815
+ super(ConcatTransform6Inputs, self).__init__(
816
+ num_bits_weight=num_bits_weight,
817
+ num_bits_activation=num_bits_activation)
818
+ self._num_bits_weight = num_bits_weight
819
+ self._num_bits_activation = num_bits_activation
820
+
821
+ def pattern(self):
822
+ return LayerPattern(
823
+ 'Concatenate',
824
+ inputs=[LayerPattern('.*'), LayerPattern('.*'), LayerPattern('.*'),
825
+ LayerPattern('.*'), LayerPattern('.*'), LayerPattern('.*')])
main.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ python main. py
2
+ Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
3
+ model. safetensors.index.json: 100%|
4
+ | 13.5k/13.5k [00:00‹?, PB/s]
5
+ model-00001-of-00002. safetensors: 100%
6
+ | 4.95G/4.95G [07:27<00:00, 11. 1MB/s]
7
+ model-00002-of-00002. safetensors: 100%
8
+ 67. 1M/67.1M [00:05<00:00, 11.5MB/s]
9
+ Downloading shards: 100% ||
10
+ | 2/2 [07:35‹00:00, 227.61s/it]
11
+ Gemma's activation function should be approximate GeLU and not exact GeLU. Changing the activation function to 'gelu_pytorch_tanh.if you want to use the legacy "gelu', edit the "model.config to
12
+ set hidden_activation=gelu*
13
+ instead of todden act
14
+ instead of hidden_act. See https://github.com/huggingface/transformers/pull/29402 for
15
+ more details.
16
+ Loading checkpoint shards: 100%|
17
+ | 2/2 [00:03<00:00, 1.87s/itl
18
+ generation_config json: 100%||
19
+ 137/137[00:00<?」3B/s]
20
+ nexa model result:
21
+ a pouto using the specified caea and resolutiou stones iption: rame rs a photo (cama a):)
22
+ Captures
23
+ - camera (str): Specifies the camera
24
+ to use. Can be \'front\' or \'back\'. The default is \'back\'. \n\n
25
+ Returns: \n
26
+ - str: The string contains the file
27
+ 2624 t 12 4a.
28
+ Photo if nees at ay 96 83662387968t, ample: /storage/emulated/o/Pictures/NAPP/3N
29
+ 123456.Jpg\'\n latency: 367.85967230796814
misc.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019, The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Misc."""
15
+
16
+ from __future__ import absolute_import
17
+ from __future__ import division
18
+ from __future__ import print_function
19
+
20
+ import collections
21
+ import tensorflow as tf
22
+
23
+ from tensorflow_model_optimization.python.core.internal.tensor_encoding.core import encoding_stage
24
+
25
+
26
+ @encoding_stage.tf_style_encoding_stage
27
+ class SplitBySmallValueEncodingStage(encoding_stage.EncodingStageInterface):
28
+ """Encoding stage splitting the input by small values.
29
+
30
+ This encoding stage will split the input into two outputs: the value and the
31
+ indices of the elements whose absolute value is larger than a certain
32
+ threshold. The elements smaller than the threshold is then decoded to zero.
33
+ """
34
+
35
+ ENCODED_INDICES_KEY = 'indices'
36
+ ENCODED_VALUES_KEY = 'non_zero_floats'
37
+ THRESHOLD_PARAMS_KEY = 'threshold'
38
+
39
+ def __init__(self, threshold=1e-8):
40
+ """Initializer for the SplitBySmallValueEncodingStage.
41
+
42
+ Args:
43
+ threshold: The threshold of the small weights to be set to zero.
44
+ """
45
+ self._threshold = threshold
46
+
47
+ @property
48
+ def name(self):
49
+ """See base class."""
50
+ return 'split_by_small_value'
51
+
52
+ @property
53
+ def compressible_tensors_keys(self):
54
+ """See base class."""
55
+ return [
56
+ self.ENCODED_VALUES_KEY,
57
+ self.ENCODED_INDICES_KEY,
58
+ ]
59
+
60
+ @property
61
+ def commutes_with_sum(self):
62
+ """See base class."""
63
+ return False
64
+
65
+ @property
66
+ def decode_needs_input_shape(self):
67
+ """See base class."""
68
+ return True
69
+
70
+ def get_params(self):
71
+ """See base class."""
72
+ encode_params = collections.OrderedDict([(self.THRESHOLD_PARAMS_KEY,
73
+ self._threshold)])
74
+ decode_params = collections.OrderedDict()
75
+ return encode_params, decode_params
76
+
77
+ def encode(self, x, encode_params):
78
+ """See base class."""
79
+
80
+ threshold = tf.cast(encode_params[self.THRESHOLD_PARAMS_KEY], x.dtype)
81
+ indices = tf.cast(tf.compat.v2.where(tf.abs(x) > threshold), tf.int32)
82
+ non_zero_x = tf.gather_nd(x, indices)
83
+ indices = tf.squeeze(indices, axis=1)
84
+ return collections.OrderedDict([
85
+ (self.ENCODED_INDICES_KEY, indices),
86
+ (self.ENCODED_VALUES_KEY, non_zero_x),
87
+ ])
88
+
89
+ def decode(self,
90
+ encoded_tensors,
91
+ decode_params,
92
+ num_summands=None,
93
+ shape=None):
94
+ """See base class."""
95
+ del decode_params, num_summands # Unused.
96
+
97
+ indices = encoded_tensors[self.ENCODED_INDICES_KEY]
98
+ non_zero_x = encoded_tensors[self.ENCODED_VALUES_KEY]
99
+
100
+ indices = tf.expand_dims(indices, 1)
101
+
102
+ indices = tf.cast(indices, tf.int64)
103
+ shape = tf.cast(shape, tf.int64)
104
+ sparse_tensor = tf.SparseTensor(indices=indices, values=non_zero_x,
105
+ dense_shape=shape)
106
+ decoded_x = tf.sparse.to_dense(sparse_tensor)
107
+
108
+ return decoded_x
109
+
110
+
111
+ @encoding_stage.tf_style_encoding_stage
112
+ class DifferenceBetweenIntegersEncodingStage(
113
+ encoding_stage.EncodingStageInterface):
114
+ """Encoding stage taking the difference between a sequence of integers.
115
+
116
+ This encoding stage can be useful when the original integers can be large, but
117
+ the difference of the integers are much smaller values and have a more compact
118
+ representation. For example, it can be combined with the
119
+ `SplitBySmallValueEncodingStage` to further compress the increasing sequence
120
+ of indices.
121
+
122
+ The encode method expects a tensor with 1 dimension and with integer dtype.
123
+ """
124
+
125
+ ENCODED_VALUES_KEY = 'difference_between_integers'
126
+
127
+ @property
128
+ def name(self):
129
+ """See base class."""
130
+ return 'difference_between_integers'
131
+
132
+ @property
133
+ def compressible_tensors_keys(self):
134
+ """See base class."""
135
+ return [
136
+ self.ENCODED_VALUES_KEY,
137
+ ]
138
+
139
+ @property
140
+ def commutes_with_sum(self):
141
+ """See base class."""
142
+ return False
143
+
144
+ @property
145
+ def decode_needs_input_shape(self):
146
+ """See base class."""
147
+ return False
148
+
149
+ def get_params(self):
150
+ """See base class."""
151
+ return collections.OrderedDict(), collections.OrderedDict()
152
+
153
+ def encode(self, x, encode_params):
154
+ """See base class."""
155
+ del encode_params # Unused.
156
+ if x.shape.ndims != 1:
157
+ raise ValueError('Number of dimensions must be 1. Shape of x: %s' %
158
+ x.shape)
159
+ if not x.dtype.is_integer:
160
+ raise TypeError(
161
+ 'Unsupported input type: %s. Support only integer types.' % x.dtype)
162
+
163
+ diff_x = x - tf.concat([[0], x[:-1]], 0)
164
+ return collections.OrderedDict([(self.ENCODED_VALUES_KEY, diff_x)])
165
+
166
+ def decode(self,
167
+ encoded_tensors,
168
+ decode_params,
169
+ num_summands=None,
170
+ shape=None):
171
+ """See base class."""
172
+ del decode_params, num_summands, shape # Unused
173
+ return tf.cumsum(encoded_tensors[self.ENCODED_VALUES_KEY])
misc_test.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019, The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import absolute_import
16
+ from __future__ import division
17
+ from __future__ import print_function
18
+
19
+ import itertools
20
+
21
+ from absl.testing import parameterized
22
+ import numpy as np
23
+ import tensorflow as tf
24
+
25
+ from tensorflow_model_optimization.python.core.internal.tensor_encoding.stages.research import misc
26
+ from tensorflow_model_optimization.python.core.internal.tensor_encoding.testing import test_utils
27
+
28
+
29
+ if tf.executing_eagerly():
30
+ tf.compat.v1.disable_eager_execution()
31
+
32
+
33
+ class SplitBySmallValueEncodingStageTest(test_utils.BaseEncodingStageTest):
34
+
35
+ def default_encoding_stage(self):
36
+ """See base class."""
37
+ return misc.SplitBySmallValueEncodingStage()
38
+
39
+ def default_input(self):
40
+ """See base class."""
41
+ return tf.random.uniform([50], minval=-1.0, maxval=1.0)
42
+
43
+ @property
44
+ def is_lossless(self):
45
+ """See base class."""
46
+ return False
47
+
48
+ def common_asserts_for_test_data(self, data):
49
+ """See base class."""
50
+ self._assert_is_integer(
51
+ data.encoded_x[misc.SplitBySmallValueEncodingStage.ENCODED_INDICES_KEY])
52
+
53
+ def _assert_is_integer(self, indices):
54
+ """Asserts that indices values are integers."""
55
+ assert indices.dtype == np.int32
56
+
57
+ @parameterized.parameters([tf.float32, tf.float64])
58
+ def test_input_types(self, x_dtype):
59
+ # Tests different input dtypes.
60
+ x = tf.constant([1.0, 0.1, 0.01, 0.001, 0.0001], dtype=x_dtype)
61
+ threshold = 0.05
62
+ stage = misc.SplitBySmallValueEncodingStage(threshold=threshold)
63
+ encode_params, decode_params = stage.get_params()
64
+ encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
65
+ decode_params)
66
+ test_data = test_utils.TestData(x, encoded_x, decoded_x)
67
+ test_data = self.evaluate_test_data(test_data)
68
+
69
+ self._assert_is_integer(test_data.encoded_x[
70
+ misc.SplitBySmallValueEncodingStage.ENCODED_INDICES_KEY])
71
+
72
+ # The numpy arrays must have the same dtype as the arrays from test_data.
73
+ expected_encoded_values = np.array([1.0, 0.1], dtype=x.dtype.as_numpy_dtype)
74
+ expected_encoded_indices = np.array([0, 1], dtype=np.int32)
75
+ expected_decoded_x = np.array([1.0, 0.1, 0., 0., 0.],
76
+ dtype=x_dtype.as_numpy_dtype)
77
+ self.assertAllEqual(test_data.encoded_x[stage.ENCODED_VALUES_KEY],
78
+ expected_encoded_values)
79
+ self.assertAllEqual(test_data.encoded_x[stage.ENCODED_INDICES_KEY],
80
+ expected_encoded_indices)
81
+ self.assertAllEqual(test_data.decoded_x, expected_decoded_x)
82
+
83
+ def test_all_zero_input_works(self):
84
+ # Tests that encoding does not blow up with all-zero input. With all-zero
85
+ # input, both of the encoded values will be empty arrays.
86
+ stage = misc.SplitBySmallValueEncodingStage()
87
+ test_data = self.run_one_to_many_encode_decode(stage,
88
+ lambda: tf.zeros([50]))
89
+
90
+ self.assertAllEqual(np.zeros((50)).astype(np.float32), test_data.decoded_x)
91
+
92
+ def test_all_below_threshold_works(self):
93
+ # Tests that encoding does not blow up with all-below-threshold input. In
94
+ # this case, both of the encoded values will be empty arrays.
95
+ stage = misc.SplitBySmallValueEncodingStage(threshold=0.1)
96
+ x = tf.random.uniform([50], minval=-0.01, maxval=0.01)
97
+ encode_params, decode_params = stage.get_params()
98
+ encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
99
+ decode_params)
100
+ test_data = test_utils.TestData(x, encoded_x, decoded_x)
101
+ test_data = self.evaluate_test_data(test_data)
102
+
103
+ expected_encoded_indices = np.array([], dtype=np.int32).reshape([0])
104
+ self.assertAllEqual(test_data.encoded_x[stage.ENCODED_VALUES_KEY], [])
105
+ self.assertAllEqual(test_data.encoded_x[stage.ENCODED_INDICES_KEY],
106
+ expected_encoded_indices)
107
+ self.assertAllEqual(test_data.decoded_x,
108
+ np.zeros([50], dtype=x.dtype.as_numpy_dtype))
109
+
110
+
111
+ class DifferenceBetweenIntegersEncodingStageTest(
112
+ test_utils.BaseEncodingStageTest):
113
+
114
+ def default_encoding_stage(self):
115
+ """See base class."""
116
+ return misc.DifferenceBetweenIntegersEncodingStage()
117
+
118
+ def default_input(self):
119
+ """See base class."""
120
+ return tf.random.uniform([10], minval=0, maxval=10, dtype=tf.int64)
121
+
122
+ @property
123
+ def is_lossless(self):
124
+ """See base class."""
125
+ return True
126
+
127
+ def common_asserts_for_test_data(self, data):
128
+ """See base class."""
129
+ self.assertAllEqual(data.x, data.decoded_x)
130
+
131
+ @parameterized.parameters(
132
+ itertools.product([[1,], [2,], [10,]], [tf.int32, tf.int64]))
133
+ def test_with_multiple_input_shapes(self, input_dims, dtype):
134
+
135
+ def x_fn():
136
+ return tf.random.uniform(input_dims, minval=0, maxval=10, dtype=dtype)
137
+
138
+ test_data = self.run_one_to_many_encode_decode(
139
+ self.default_encoding_stage(), x_fn)
140
+ self.common_asserts_for_test_data(test_data)
141
+
142
+ def test_empty_input_static(self):
143
+ # Tests that the encoding works when the input shape is [0].
144
+ x = []
145
+ x = tf.convert_to_tensor(x, dtype=tf.int32)
146
+ assert x.shape.as_list() == [0]
147
+
148
+ stage = self.default_encoding_stage()
149
+ encode_params, decode_params = stage.get_params()
150
+ encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
151
+ decode_params)
152
+
153
+ test_data = self.evaluate_test_data(
154
+ test_utils.TestData(x, encoded_x, decoded_x))
155
+ self.common_asserts_for_test_data(test_data)
156
+
157
+ def test_empty_input_dynamic(self):
158
+ # Tests that the encoding works when the input shape is [0], but not
159
+ # statically known.
160
+ y = tf.zeros((10,))
161
+ indices = tf.compat.v2.where(tf.abs(y) > 1e-8)
162
+ x = tf.gather_nd(y, indices)
163
+ x = tf.cast(x, tf.int32) # Empty tensor.
164
+ assert x.shape.as_list() == [None]
165
+ stage = self.default_encoding_stage()
166
+ encode_params, decode_params = stage.get_params()
167
+ encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
168
+ decode_params)
169
+
170
+ test_data = self.evaluate_test_data(
171
+ test_utils.TestData(x, encoded_x, decoded_x))
172
+ assert test_data.x.shape == (0,)
173
+ assert test_data.encoded_x[stage.ENCODED_VALUES_KEY].shape == (0,)
174
+ assert test_data.decoded_x.shape == (0,)
175
+
176
+ @parameterized.parameters([tf.bool, tf.float32])
177
+ def test_encode_unsupported_type_raises(self, dtype):
178
+ stage = self.default_encoding_stage()
179
+ with self.assertRaisesRegexp(TypeError, 'Unsupported input type'):
180
+ self.run_one_to_many_encode_decode(
181
+ stage, lambda: tf.cast(self.default_input(), dtype))
182
+
183
+ def test_encode_unsupported_input_shape_raises(self):
184
+ x = tf.random.uniform((3, 4), maxval=10, dtype=tf.int32)
185
+ stage = self.default_encoding_stage()
186
+ params, _ = stage.get_params()
187
+ with self.assertRaisesRegexp(ValueError, 'Number of dimensions must be 1'):
188
+ stage.encode(x, params)
189
+
190
+
191
+ if __name__ == '__main__':
192
+ tf.test.main()
mnist_cnn.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ # pylint: disable=missing-docstring
16
+ """Train a simple convnet on the MNIST dataset."""
17
+ from __future__ import print_function
18
+
19
+ from absl import app as absl_app
20
+ from absl import flags
21
+ import tensorflow as tf
22
+
23
+ from tensorflow_model_optimization.python.core.keras.compat import keras
24
+ from tensorflow_model_optimization.python.core.sparsity.keras import prune
25
+ from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
26
+ from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
27
+
28
+
29
+ PolynomialDecay = pruning_schedule.PolynomialDecay
30
+ l = keras.layers
31
+
32
+ FLAGS = flags.FLAGS
33
+
34
+ batch_size = 128
35
+ num_classes = 10
36
+ epochs = 12
37
+
38
+ flags.DEFINE_string('output_dir', '/tmp/mnist_train/',
39
+ 'Output directory to hold tensorboard events')
40
+
41
+
42
+ def build_sequential_model(input_shape):
43
+ return keras.Sequential([
44
+ l.Conv2D(
45
+ 32, 5, padding='same', activation='relu', input_shape=input_shape
46
+ ),
47
+ l.MaxPooling2D((2, 2), (2, 2), padding='same'),
48
+ l.BatchNormalization(),
49
+ l.Conv2D(64, 5, padding='same', activation='relu'),
50
+ l.MaxPooling2D((2, 2), (2, 2), padding='same'),
51
+ l.Flatten(),
52
+ l.Dense(1024, activation='relu'),
53
+ l.Dropout(0.4),
54
+ l.Dense(num_classes, activation='softmax'),
55
+ ])
56
+
57
+
58
+ def build_functional_model(input_shape):
59
+ inp = keras.Input(shape=input_shape)
60
+ x = l.Conv2D(32, 5, padding='same', activation='relu')(inp)
61
+ x = l.MaxPooling2D((2, 2), (2, 2), padding='same')(x)
62
+ x = l.BatchNormalization()(x)
63
+ x = l.Conv2D(64, 5, padding='same', activation='relu')(x)
64
+ x = l.MaxPooling2D((2, 2), (2, 2), padding='same')(x)
65
+ x = l.Flatten()(x)
66
+ x = l.Dense(1024, activation='relu')(x)
67
+ x = l.Dropout(0.4)(x)
68
+ out = l.Dense(num_classes, activation='softmax')(x)
69
+
70
+ return keras.models.Model([inp], [out])
71
+
72
+
73
+ def build_layerwise_model(input_shape, **pruning_params):
74
+ return keras.Sequential([
75
+ prune.prune_low_magnitude(
76
+ l.Conv2D(32, 5, padding='same', activation='relu'),
77
+ input_shape=input_shape,
78
+ **pruning_params
79
+ ),
80
+ l.MaxPooling2D((2, 2), (2, 2), padding='same'),
81
+ l.BatchNormalization(),
82
+ prune.prune_low_magnitude(
83
+ l.Conv2D(64, 5, padding='same', activation='relu'), **pruning_params
84
+ ),
85
+ l.MaxPooling2D((2, 2), (2, 2), padding='same'),
86
+ l.Flatten(),
87
+ prune.prune_low_magnitude(
88
+ l.Dense(1024, activation='relu'), **pruning_params
89
+ ),
90
+ l.Dropout(0.4),
91
+ prune.prune_low_magnitude(
92
+ l.Dense(num_classes, activation='softmax'), **pruning_params
93
+ ),
94
+ ])
95
+
96
+
97
+ def train_and_save(models, x_train, y_train, x_test, y_test):
98
+ for model in models:
99
+ model.compile(
100
+ loss=keras.losses.categorical_crossentropy,
101
+ optimizer='adam',
102
+ metrics=['accuracy'],
103
+ )
104
+
105
+ # Print the model summary.
106
+ model.summary()
107
+
108
+ # Add a pruning step callback to peg the pruning step to the optimizer's
109
+ # step. Also add a callback to add pruning summaries to tensorboard
110
+ callbacks = [
111
+ pruning_callbacks.UpdatePruningStep(),
112
+ pruning_callbacks.PruningSummaries(log_dir=FLAGS.output_dir)
113
+ ]
114
+
115
+ model.fit(
116
+ x_train,
117
+ y_train,
118
+ batch_size=batch_size,
119
+ epochs=epochs,
120
+ verbose=1,
121
+ callbacks=callbacks,
122
+ validation_data=(x_test, y_test))
123
+ score = model.evaluate(x_test, y_test, verbose=0)
124
+ print('Test loss:', score[0])
125
+ print('Test accuracy:', score[1])
126
+
127
+ # Export and import the model. Check that accuracy persists.
128
+ saved_model_dir = '/tmp/saved_model'
129
+ print('Saving model to: ', saved_model_dir)
130
+ keras.models.save_model(model, saved_model_dir, save_format='tf')
131
+ print('Loading model from: ', saved_model_dir)
132
+ loaded_model = keras.models.load_model(saved_model_dir)
133
+
134
+ score = loaded_model.evaluate(x_test, y_test, verbose=0)
135
+ print('Test loss:', score[0])
136
+ print('Test accuracy:', score[1])
137
+
138
+
139
+ def main(unused_argv):
140
+ # input image dimensions
141
+ img_rows, img_cols = 28, 28
142
+
143
+ # the data, shuffled and split between train and test sets
144
+ (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
145
+
146
+ if keras.backend.image_data_format() == 'channels_first':
147
+ x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
148
+ x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
149
+ input_shape = (1, img_rows, img_cols)
150
+ else:
151
+ x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
152
+ x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
153
+ input_shape = (img_rows, img_cols, 1)
154
+
155
+ x_train = x_train.astype('float32')
156
+ x_test = x_test.astype('float32')
157
+ x_train /= 255
158
+ x_test /= 255
159
+ print('x_train shape:', x_train.shape)
160
+ print(x_train.shape[0], 'train samples')
161
+ print(x_test.shape[0], 'test samples')
162
+
163
+ # convert class vectors to binary class matrices
164
+ y_train = keras.utils.to_categorical(y_train, num_classes)
165
+ y_test = keras.utils.to_categorical(y_test, num_classes)
166
+
167
+ pruning_params = {
168
+ 'pruning_schedule':
169
+ PolynomialDecay(
170
+ initial_sparsity=0.1,
171
+ final_sparsity=0.75,
172
+ begin_step=1000,
173
+ end_step=5000,
174
+ frequency=100)
175
+ }
176
+
177
+ layerwise_model = build_layerwise_model(input_shape, **pruning_params)
178
+ sequential_model = build_sequential_model(input_shape)
179
+ sequential_model = prune.prune_low_magnitude(
180
+ sequential_model, **pruning_params)
181
+ functional_model = build_functional_model(input_shape)
182
+ functional_model = prune.prune_low_magnitude(
183
+ functional_model, **pruning_params)
184
+
185
+ models = [layerwise_model, sequential_model, functional_model]
186
+ train_and_save(models, x_train, y_train, x_test, y_test)
187
+
188
+
189
+ if __name__ == '__main__':
190
+ absl_app.run(main)
mnist_e2e_sparsity2x4.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ # pylint: disable=missing-docstring,protected-access
16
+ """Train a simple convnet on the MNIST dataset with sparsity 2x4.
17
+
18
+ It is based on mnist_e2e.py
19
+ """
20
+ from __future__ import print_function
21
+
22
+ from absl import app as absl_app
23
+ import tensorflow as tf
24
+
25
+ from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
26
+ from tensorflow_model_optimization.python.core.keras.compat import keras
27
+ from tensorflow_model_optimization.python.core.sparsity.keras import prune
28
+ from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
29
+ from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
30
+ from tensorflow_model_optimization.python.core.sparsity.keras import pruning_utils
31
+ from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
32
+
33
+
34
+ ConstantSparsity = pruning_schedule.ConstantSparsity
35
+ l = keras.layers
36
+
37
+ tf.random.set_seed(42)
38
+
39
+ batch_size = 128
40
+ num_classes = 10
41
+ epochs = 1
42
+
43
+ PRUNABLE_2x4_LAYERS = (keras.layers.Conv2D, keras.layers.Dense)
44
+
45
+
46
+ def check_model_sparsity_2x4(model):
47
+ for layer in model.layers:
48
+ if isinstance(layer, pruning_wrapper.PruneLowMagnitude) and isinstance(
49
+ layer.layer, PRUNABLE_2x4_LAYERS):
50
+ for weight in layer.layer.get_prunable_weights():
51
+ if not pruning_utils.is_pruned_m_by_n(weight):
52
+ return False
53
+ return True
54
+
55
+
56
+ def build_layerwise_model(input_shape, **pruning_params):
57
+ return keras.Sequential([
58
+ prune.prune_low_magnitude(
59
+ l.Conv2D(
60
+ 32, 5, padding='same', activation='relu', input_shape=input_shape
61
+ ),
62
+ **pruning_params
63
+ ),
64
+ l.MaxPooling2D((2, 2), (2, 2), padding='same'),
65
+ prune.prune_low_magnitude(
66
+ l.Conv2D(64, 5, padding='same'), **pruning_params
67
+ ),
68
+ l.BatchNormalization(),
69
+ l.ReLU(),
70
+ l.MaxPooling2D((2, 2), (2, 2), padding='same'),
71
+ l.Flatten(),
72
+ prune.prune_low_magnitude(
73
+ l.Dense(1024, activation='relu'), **pruning_params
74
+ ),
75
+ l.Dropout(0.4),
76
+ l.Dense(num_classes, activation='softmax'),
77
+ ])
78
+
79
+
80
+ def train(model, x_train, y_train, x_test, y_test):
81
+ model.compile(
82
+ loss=keras.losses.categorical_crossentropy,
83
+ optimizer='adam',
84
+ metrics=['accuracy'],
85
+ )
86
+ model.run_eagerly = True
87
+
88
+ # Print the model summary.
89
+ model.summary()
90
+
91
+ # Add a pruning step callback to peg the pruning step to the optimizer's
92
+ # step. Also add a callback to add pruning summaries to tensorboard
93
+ callbacks = [
94
+ pruning_callbacks.UpdatePruningStep(),
95
+ pruning_callbacks.PruningSummaries(log_dir='/tmp/logs')
96
+ ]
97
+
98
+ model.fit(
99
+ x_train,
100
+ y_train,
101
+ batch_size=batch_size,
102
+ epochs=epochs,
103
+ verbose=1,
104
+ callbacks=callbacks,
105
+ validation_data=(x_test, y_test))
106
+ score = model.evaluate(x_test, y_test, verbose=0)
107
+ print('Test loss:', score[0])
108
+ print('Test accuracy:', score[1])
109
+
110
+ # Check sparsity 2x4 type before stripping pruning
111
+ is_pruned_2x4 = check_model_sparsity_2x4(model)
112
+ print('Pass the check for sparsity 2x4: ', is_pruned_2x4)
113
+
114
+ model = prune.strip_pruning(model)
115
+ return model
116
+
117
+
118
+ def main(unused_argv):
119
+ ##############################################################################
120
+ # Prepare training and testing data
121
+ ##############################################################################
122
+ (x_train, y_train), (
123
+ x_test,
124
+ y_test), input_shape = keras_test_utils.get_preprocessed_mnist_data()
125
+
126
+ ##############################################################################
127
+ # Train a model with sparsity 2x4.
128
+ ##############################################################################
129
+ pruning_params = {
130
+ 'pruning_schedule': ConstantSparsity(0.5, begin_step=0, frequency=100),
131
+ 'sparsity_m_by_n': (2, 4),
132
+ }
133
+
134
+ model = build_layerwise_model(input_shape, **pruning_params)
135
+ pruned_model = train(model, x_train, y_train, x_test, y_test)
136
+
137
+ # Write a model that has been pruned with 2x4 sparsity.
138
+ converter = tf.lite.TFLiteConverter.from_keras_model(pruned_model)
139
+ tflite_model = converter.convert()
140
+
141
+ tflite_model_path = '/tmp/mnist_2x4.tflite'
142
+ print('model is saved to {}'.format(tflite_model_path))
143
+ with open(tflite_model_path, 'wb') as f:
144
+ f.write(tflite_model)
145
+
146
+ print('evaluate pruned model: ')
147
+ print(keras_test_utils.eval_mnist_tflite(model_content=tflite_model))
148
+ # the accuracy of 2:4 pruning model is 0.9866
149
+ # the accuracy of unstructured model with 50% is 0.9863
150
+
151
+
152
+ if __name__ == '__main__':
153
+ absl_app.run(main)
periodical_update_and_scheduling_test.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Tests for when the training and inference graphs are the same."""
16
+
17
+ import os
18
+ import tempfile
19
+
20
+ import tensorflow as tf
21
+
22
+ from tensorflow_model_optimization.python.core.common.keras.compression.algorithms import periodical_update_and_scheduling as svd
23
+ from tensorflow_model_optimization.python.core.keras.compat import keras
24
+ from tensorflow_model_optimization.python.core.keras.testing import test_utils_mnist
25
+
26
+
27
+ def _build_model():
28
+ i = keras.layers.Input(shape=(28, 28), name='input')
29
+ x = keras.layers.Reshape((28, 28, 1))(i)
30
+ x = keras.layers.Conv2D(
31
+ 20, 5, activation='relu', padding='valid', name='conv1'
32
+ )(x)
33
+ x = keras.layers.MaxPool2D(2, 2)(x)
34
+ x = keras.layers.Conv2D(
35
+ 50, 5, activation='relu', padding='valid', name='conv2'
36
+ )(x)
37
+ x = keras.layers.MaxPool2D(2, 2)(x)
38
+ x = keras.layers.Flatten()(x)
39
+ x = keras.layers.Dense(500, activation='relu', name='fc1')(x)
40
+ output = keras.layers.Dense(10, name='fc2')(x)
41
+
42
+ model = keras.Model(inputs=[i], outputs=[output])
43
+ return model
44
+
45
+
46
+ def _get_dataset():
47
+ mnist = keras.datasets.mnist
48
+ (x_train, y_train), (x_test, y_test) = mnist.load_data()
49
+ x_train, x_test = x_train / 255.0, x_test / 255.0
50
+ # Use subset of 60000 examples to keep unit test speed fast.
51
+ x_train = x_train[0:1000]
52
+ y_train = y_train[0:1000]
53
+ return (x_train, y_train), (x_test, y_test)
54
+
55
+
56
+ def _train_model(model):
57
+ loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
58
+
59
+ model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
60
+
61
+ (x_train, y_train), _ = _get_dataset()
62
+
63
+ model.fit(x_train, y_train, epochs=1)
64
+
65
+
66
+ def _save_as_saved_model(model):
67
+ saved_model_dir = tempfile.mkdtemp()
68
+ model.save(saved_model_dir)
69
+ return saved_model_dir
70
+
71
+
72
+ # TODO(tfmot): reuse existing test utilities.
73
+ def _convert_to_tflite(saved_model_dir):
74
+ _, tflite_file = tempfile.mkstemp()
75
+
76
+ converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
77
+ tflite_model = converter.convert()
78
+
79
+ with open(tflite_file, 'wb') as f:
80
+ f.write(tflite_model)
81
+
82
+ return tflite_file
83
+
84
+
85
+ def _get_directory_size_in_bytes(directory):
86
+ total = 0
87
+ try:
88
+ for entry in os.scandir(directory):
89
+ if entry.is_file():
90
+ # if it's a file, use stat() function
91
+ total += entry.stat().st_size
92
+ elif entry.is_dir():
93
+ # if it's a directory, recursively call this function
94
+ total += _get_directory_size_in_bytes(entry.path)
95
+ except NotADirectoryError:
96
+ # if `directory` isn't a directory, get the file size then
97
+ return os.path.getsize(directory)
98
+ except PermissionError:
99
+ # if for whatever reason we can't open the folder, return 0
100
+ return 0
101
+ return total
102
+
103
+
104
+ class FunctionalTest(tf.test.TestCase):
105
+
106
+ # TODO(tfmot): can simplify to single layer test that checks exact
107
+ # dimensions of weights.
108
+ def testSVD_ReducesSavedModelSize(self):
109
+ model = _build_model()
110
+
111
+ original_saved_model_dir = _save_as_saved_model(model)
112
+
113
+ algorithm = svd.SVD(rank=16, update_freq=1, warmup_step=10)
114
+ training_model = algorithm.optimize_model(model)
115
+ compressed_model = algorithm.compress_model(training_model)
116
+
117
+ saved_model_dir = _save_as_saved_model(compressed_model)
118
+
119
+ original_size = _get_directory_size_in_bytes(original_saved_model_dir)
120
+ compressed_size = _get_directory_size_in_bytes(saved_model_dir)
121
+
122
+ self.assertLess(compressed_size, original_size / 3)
123
+
124
+ def testSVD_HasReasonableAccuracy_TF(self):
125
+ model = _build_model()
126
+
127
+ algorithm = svd.SVD(rank=16, update_freq=1, warmup_step=10)
128
+ training_model = algorithm.optimize_model(model)
129
+
130
+ _train_model(training_model)
131
+
132
+ compressed_model = algorithm.compress_model(training_model)
133
+
134
+ _, (x_test, y_test) = _get_dataset()
135
+
136
+ loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
137
+
138
+ compressed_model.compile(
139
+ optimizer='adam', loss=loss_fn, metrics=['accuracy'])
140
+
141
+ results = compressed_model.evaluate(x_test, y_test)
142
+
143
+ self.assertGreater(results[1], 0.60)
144
+
145
+ def testSVD_ReducesTFLiteModelSize(self):
146
+ model = _build_model()
147
+
148
+ original_saved_model_dir = _save_as_saved_model(model)
149
+ original_tflite_file = _convert_to_tflite(original_saved_model_dir)
150
+
151
+ algorithm = svd.SVD(rank=16, update_freq=1, warmup_step=10)
152
+ training_model = algorithm.optimize_model(model)
153
+ compressed_model = algorithm.compress_model(training_model)
154
+
155
+ saved_model_dir = _save_as_saved_model(compressed_model)
156
+ compressed_tflite_file = _convert_to_tflite(saved_model_dir)
157
+
158
+ original_size = os.path.getsize(original_tflite_file)
159
+ compressed_size = os.path.getsize(compressed_tflite_file)
160
+
161
+ self.assertLess(compressed_size, original_size / 6)
162
+
163
+ def testSVD_HasReasonableAccuracy_TFLite(self):
164
+ model = _build_model()
165
+
166
+ algorithm = svd.SVD(rank=16, update_freq=1, warmup_step=10)
167
+ training_model = algorithm.optimize_model(model)
168
+
169
+ _train_model(training_model)
170
+
171
+ compressed_model = algorithm.compress_model(training_model)
172
+
173
+ saved_model_dir = _save_as_saved_model(compressed_model)
174
+ compressed_tflite_file = _convert_to_tflite(saved_model_dir)
175
+
176
+ accuracy = test_utils_mnist.eval_tflite(compressed_tflite_file)
177
+
178
+ self.assertGreater(accuracy, 0.60)
179
+
180
+ # TODO(tfmot): can simplify to single layer test.
181
+ def testSVD_BreaksDownLayerWeights(self):
182
+ model = _build_model()
183
+
184
+ first_conv_layer = model.layers[2]
185
+ self.assertLen(first_conv_layer.weights, 2)
186
+
187
+ algorithm = svd.SVD(rank=16, update_freq=1, warmup_step=10)
188
+ training_model = algorithm.optimize_model(model)
189
+ compressed_model = algorithm.compress_model(training_model)
190
+
191
+ first_conv_layer = compressed_model.layers[2]
192
+
193
+ self.assertLen(first_conv_layer.weights, 3)
194
+
195
+ # TODO(tfmot): can simplify to single layer test.
196
+ def testSVD_PreservesPretrainedWeights(self):
197
+ i = keras.layers.Input(shape=(2), name='input')
198
+ output = keras.layers.Dense(3, name='fc1')(i)
199
+ model = keras.Model(inputs=[i], outputs=[output])
200
+
201
+ dense_layer_weights = model.layers[1].get_weights()
202
+
203
+ algorithm = svd.SVD(rank=1, update_freq=1, warmup_step=10)
204
+ training_model = algorithm.optimize_model(model)
205
+
206
+ dense_layer_training_weights = training_model.layers[1].get_weights()
207
+
208
+ # kernel
209
+ algorithm.weight_reprs = []
210
+ algorithm.init_training_weights(dense_layer_weights[0])
211
+ w1_repr, w2_repr = algorithm.weight_reprs
212
+ assert (w1_repr.kwargs['initializer'](None) == \
213
+ dense_layer_training_weights[0]).numpy().all()
214
+ assert (w2_repr.kwargs['initializer'](None) == \
215
+ dense_layer_training_weights[1]).numpy().all()
216
+
217
+ # bias
218
+ assert (dense_layer_weights[1] == dense_layer_training_weights[2]).all()
219
+
220
+
221
+ if __name__ == '__main__':
222
+ tf.test.main()
prune_preserve_quantize_registry.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Registry responsible for built-in keras classes."""
16
+
17
+ import tensorflow as tf
18
+
19
+ from tensorflow_model_optimization.python.core.keras.compat import keras
20
+ from tensorflow_model_optimization.python.core.quantization.keras import quant_ops
21
+ from tensorflow_model_optimization.python.core.quantization.keras import quantizers
22
+ from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import (
23
+ default_8bit_quantize_registry,)
24
+ from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import (
25
+ default_8bit_quantizers,)
26
+
27
+
28
+ layers = keras.layers
29
+
30
+
31
+ class _PrunePreserveInfo(object):
32
+ """PrunePreserveInfo."""
33
+
34
+ def __init__(self, weight_attrs, quantize_config_attrs):
35
+ """Initializes PrunePreserveInfo.
36
+
37
+ Args:
38
+ weight_attrs: list of sparsity preservable weight attributes of layer.
39
+ quantize_config_attrs: list of quantization configuration class name.
40
+ """
41
+ self.weight_attrs = weight_attrs
42
+ self.quantize_config_attrs = quantize_config_attrs
43
+
44
+
45
+ class PrunePreserveQuantizeRegistry():
46
+ """PrunePreserveQuantizeRegistry responsible for built-in keras layers."""
47
+
48
+ # The keys represent built-in keras layers; the first values represent the
49
+ # the variables within the layers which hold the kernel weights, second
50
+ # values represent the class name of quantization configuration for layers.
51
+ # This decide the weights of layers with quantization configurations are
52
+ # sparsity preservable.
53
+ _LAYERS_CONFIG_MAP = {
54
+ layers.Conv2D:
55
+ _PrunePreserveInfo(['kernel'], ['Default8BitConvQuantizeConfig']),
56
+ layers.Dense:
57
+ _PrunePreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),
58
+
59
+ # DepthwiseConv2D is supported with 8bit qat, but not with prune,
60
+ # thus for DepthwiseConv2D PQAT, weights sparsity preserve is disabled.
61
+ layers.DepthwiseConv2D:
62
+ _PrunePreserveInfo(['depthwise_kernel'], ['Default8BitQuantizeConfig']),
63
+
64
+ # layers that supported with prune, but not yet with QAT
65
+ # layers.Conv1D:
66
+ # _PrunePreserveInfo(['kernel'], []),
67
+ # layers.Conv2DTranspose:
68
+ # _PrunePreserveInfo(['kernel'], []),
69
+ # layers.Conv3D:
70
+ # _PrunePreserveInfo(['kernel'], []),
71
+ # layers.Conv3DTranspose:
72
+ # _PrunePreserveInfo(['kernel'], []),
73
+ # layers.LocallyConnected1D:
74
+ # _PrunePreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),
75
+ # layers.LocallyConnected2D:
76
+ # _PrunePreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),
77
+
78
+ # SeparableConv need verify from 8bit qat
79
+ # layers.SeparableConv1D:
80
+ # _PrunePreserveInfo(['pointwise_kernel'], \
81
+ # ['Default8BitConvQuantizeConfig']),
82
+ # layers.SeparableConv2D:
83
+ # _PrunePreserveInfo(['pointwise_kernel'], \
84
+ # ['Default8BitConvQuantizeConfig']),
85
+
86
+ # Embedding need verify from 8bit qat
87
+ # layers.Embedding: _PrunePreserveInfo(['embeddings'], []),
88
+ }
89
+
90
+ _DISABLE_PRUNE_PRESERVE = frozenset({
91
+ layers.DepthwiseConv2D,
92
+ })
93
+
94
+ def __init__(self):
95
+
96
+ self._config_quantizer_map = {
97
+ 'Default8BitQuantizeConfig':
98
+ PrunePreserveDefault8BitWeightsQuantizer(),
99
+ 'Default8BitConvQuantizeConfig':
100
+ PrunePreserveDefault8BitConvWeightsQuantizer(),
101
+ }
102
+
103
+ @classmethod
104
+ def _no_trainable_weights(cls, layer):
105
+ """Returns whether this layer has trainable weights.
106
+
107
+ Args:
108
+ layer: The layer to check for trainable weights.
109
+
110
+ Returns:
111
+ True/False whether the layer has trainable weights.
112
+ """
113
+ return not layer.trainable_weights
114
+
115
+ @classmethod
116
+ def _disable_prune_preserve(cls, layer):
117
+ """Returns whether disable this layer for prune preserve.
118
+
119
+ Args:
120
+ layer: The layer to check for disable.
121
+
122
+ Returns:
123
+ True/False whether disable this layer for prune preserve.
124
+ """
125
+
126
+ return layer.__class__ in cls._DISABLE_PRUNE_PRESERVE
127
+
128
+ @classmethod
129
+ def supports(cls, layer):
130
+ """Returns whether the registry supports this layer type.
131
+
132
+ Args:
133
+ layer: The layer to check for support.
134
+
135
+ Returns:
136
+ True/False whether the layer type is supported.
137
+ """
138
+
139
+ # layers without trainable weights are considered supported,
140
+ # e.g., ReLU, Softmax, and AveragePooling2D.
141
+ if cls._no_trainable_weights(layer):
142
+ return True
143
+
144
+ if layer.__class__ in cls._LAYERS_CONFIG_MAP:
145
+ return True
146
+
147
+ return False
148
+
149
+ @classmethod
150
+ def _weight_names(cls, layer):
151
+ """Gets the weight names."""
152
+ if cls._no_trainable_weights(layer):
153
+ return []
154
+
155
+ return cls._LAYERS_CONFIG_MAP[layer.__class__].weight_attrs
156
+
157
+ @classmethod
158
+ def get_sparsity_preservable_weights(cls, layer):
159
+ """Gets sparsity preservable weights from keras layer.
160
+
161
+ Args:
162
+ layer: instance of keras layer
163
+
164
+ Returns:
165
+ List of sparsity preservable weights
166
+ """
167
+ return [getattr(layer, weight) for weight in cls._weight_names(layer)]
168
+
169
+ @classmethod
170
+ def get_suppport_quantize_config_names(cls, layer):
171
+ """Gets class name of supported quantize config for layer.
172
+
173
+ Args:
174
+ layer: instance of keras layer
175
+
176
+ Returns:
177
+ List of supported quantize config class name.
178
+ """
179
+
180
+ # layers without trainable weights don't need quantize_config for pqat
181
+ if cls._no_trainable_weights(layer):
182
+ return []
183
+
184
+ return cls._LAYERS_CONFIG_MAP[layer.__class__].quantize_config_attrs
185
+
186
+ def apply_sparsity_preserve_quantize_config(self, layer, quantize_config):
187
+ """Applies weights sparsity preservation.
188
+
189
+ Args:
190
+ layer: The layer to check for support.
191
+ quantize_config: quantization config to check for support,
192
+ apply sparsity preservation to pruned weights
193
+ Raises:
194
+ ValueError when layer is supported does not have quantization config.
195
+ Returns:
196
+ Returns quantize_config with addon sparsity preserve weight_quantizer.
197
+ """
198
+ if self.supports(layer):
199
+ if (self._no_trainable_weights(layer) or
200
+ self._disable_prune_preserve(layer)):
201
+ return quantize_config
202
+ if (quantize_config.__class__.__name__
203
+ in self._LAYERS_CONFIG_MAP[layer.__class__].quantize_config_attrs):
204
+ quantize_config.weight_quantizer = self._config_quantizer_map[
205
+ quantize_config.__class__.__name__]
206
+ else:
207
+ raise ValueError('Configuration {} is not supported for Layer {}.'
208
+ .format(str(quantize_config.__class__.__name__),
209
+ str(layer.__class__.__name__)))
210
+ else:
211
+ raise ValueError('Layer {} is not supported.'.format(
212
+ str(layer.__class__.__name__)))
213
+
214
+ return quantize_config
215
+
216
+
217
+ class Default8bitPrunePreserveQuantizeRegistry(PrunePreserveQuantizeRegistry):
218
+ """Default 8 bit PrunePreserveQuantizeRegistry."""
219
+
220
+ def get_quantize_config(self, layer):
221
+ """Returns the quantization config with addon sparsity.
222
+
223
+ Args:
224
+ layer: input layer to return quantize config for.
225
+
226
+ Returns:
227
+ Returns the quantization config with sparsity preserve weight_quantizer.
228
+ """
229
+ quantize_config = (default_8bit_quantize_registry
230
+ .Default8BitQuantizeRegistry()
231
+ .get_quantize_config(layer))
232
+ prune_aware_quantize_config = self.apply_sparsity_preserve_quantize_config(
233
+ layer, quantize_config)
234
+
235
+ return prune_aware_quantize_config
236
+
237
+
238
+ class PrunePreserveDefaultWeightsQuantizer(quantizers.LastValueQuantizer):
239
+ """Quantize weights while preserve sparsity."""
240
+
241
+ def __init__(self, num_bits, per_axis, symmetric, narrow_range):
242
+ """Initializes PrunePreserveDefaultWeightsQuantizer.
243
+
244
+ Args:
245
+ num_bits: Number of bits for quantization
246
+ per_axis: Whether to apply per_axis quantization. The last dimension is
247
+ used as the axis.
248
+ symmetric: If true, use symmetric quantization limits instead of training
249
+ the minimum and maximum of each quantization range separately.
250
+ narrow_range: In case of 8 bits, narrow_range nudges the quantized range
251
+ to be [-127, 127] instead of [-128, 127]. This ensures symmetric range
252
+ has 0 as the centre.
253
+ """
254
+ quantizers.LastValueQuantizer.__init__(self, num_bits, per_axis, symmetric,
255
+ narrow_range)
256
+
257
+ def _build_sparsity_mask(self, name, layer):
258
+ weights = getattr(layer.layer, name)
259
+ sparsity_mask = tf.math.divide_no_nan(weights, weights)
260
+
261
+ return {'sparsity_mask': sparsity_mask}
262
+
263
+ def build(self, tensor_shape, name, layer):
264
+ """Constructs mask to preserve weights sparsity.
265
+
266
+ Args:
267
+ tensor_shape: Shape of weights which needs to be quantized.
268
+ name: Name of weights in layer.
269
+ layer: quantization wrapped keras layer.
270
+
271
+ Returns:
272
+ Dictionary of constructed sparsity mask and
273
+ quantization params, the dictionary will be passed
274
+ to __call__ function.
275
+ """
276
+ result = self._build_sparsity_mask(name, layer)
277
+ result.update(
278
+ super(PrunePreserveDefaultWeightsQuantizer,
279
+ self).build(tensor_shape, name, layer))
280
+ return result
281
+
282
+ def __call__(self, inputs, training, weights, **kwargs):
283
+ """Applies sparsity preserved quantization to the input tensor.
284
+
285
+ Args:
286
+ inputs: Input tensor (layer's weights) to be quantized.
287
+ training: Whether the graph is currently training.
288
+ weights: Dictionary of weights (params) the quantizer can use to
289
+ quantize the tensor (layer's weights). This contains the weights
290
+ created in the `build` function.
291
+ **kwargs: Additional variables which may be passed to the quantizer.
292
+
293
+ Returns:
294
+ quantized tensor.
295
+ """
296
+
297
+ prune_preserve_inputs = tf.multiply(inputs, weights['sparsity_mask'])
298
+
299
+ return quant_ops.LastValueQuantize(
300
+ prune_preserve_inputs,
301
+ weights['min_var'],
302
+ weights['max_var'],
303
+ is_training=training,
304
+ num_bits=self.num_bits,
305
+ per_channel=self.per_axis,
306
+ symmetric=self.symmetric,
307
+ narrow_range=self.narrow_range,
308
+ )
309
+
310
+
311
+ class PrunePreserveDefault8BitWeightsQuantizer(
312
+ PrunePreserveDefaultWeightsQuantizer):
313
+ """PrunePreserveWeightsQuantizer for default 8bit weights."""
314
+
315
+ def __init__(self):
316
+ super(PrunePreserveDefault8BitWeightsQuantizer,
317
+ self).__init__(num_bits=8,
318
+ per_axis=False,
319
+ symmetric=True,
320
+ narrow_range=True)
321
+
322
+
323
+ class PrunePreserveDefault8BitConvWeightsQuantizer(
324
+ PrunePreserveDefaultWeightsQuantizer,
325
+ default_8bit_quantizers.Default8BitConvWeightsQuantizer,):
326
+ """PrunePreserveWeightsQuantizer for default 8bit Conv2D/DepthwiseConv2D weights."""
327
+
328
+ # pylint: disable=super-init-not-called
329
+ def __init__(self):
330
+ # Skip PrunePreserveDefaultWeightsQuantizer since they have the same super.
331
+ default_8bit_quantizers.Default8BitConvWeightsQuantizer.__init__(self)
332
+
333
+ def build(self, tensor_shape, name, layer):
334
+ result = PrunePreserveDefaultWeightsQuantizer._build_sparsity_mask(
335
+ self, name, layer)
336
+ result.update(
337
+ default_8bit_quantizers.Default8BitConvWeightsQuantizer.build(
338
+ self, tensor_shape, name, layer))
339
+ return result
readme.txt ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ixl iWARP FreeBSD* driver for Intel(R) Ethernet Connection X722
2
+ ================================================================
3
+ July 9, 2019
4
+
5
+ Contents
6
+ ========
7
+
8
+ - Prerequisites
9
+ - Building and Installation
10
+ - Testing
11
+ - Configuration
12
+ - Interoperability
13
+ - Known Issues
14
+
15
+
16
+ Prerequisites
17
+ =============
18
+
19
+ - FreeBSD version 11.2
20
+ - Kernel configuration:
21
+ Please add the following kernel configuration options:
22
+ include GENERIC
23
+ options COMPAT_LINUXKPI
24
+ options IPOIB_CM
25
+ options IXL_IW
26
+
27
+ nodevice ixl
28
+ nodevice iavf
29
+ Note: IXL_IW is required for FreeBSD-CURRENT branch.
30
+ - For the iw_ixl driver to work, an if_ixl driver with iwarp interface
31
+ is required. The interface is available in if_ixl version 1.7.12 or later.
32
+ It should be enabled prior to usage, as the setting is switched off by
33
+ default. To enable iwarp compatibility, add
34
+ hw.ixl.enable_iwarp=1
35
+ to
36
+ /boot/loader.conf
37
+
38
+ The lan driver can be downloaded from
39
+ https://downloadcenter.intel.com/download/25160/Ethernet-Intel-Network-Adapter-D
40
+ river-for-PCIe-40-Gigabit-Ethernet-Network-Connection-under-FreeBSD
41
+ Or search on downloadcenter.intel.com using '40 Gigabit Ethernet Network
42
+ Connection under FreeBSD'. Newer OS releases contain the if_ixl driver in
43
+ the ixl driver version 1.7.12-k or later.
44
+
45
+ There are some known issues with the interface on if_ixl-1.7.12. Please
46
+ use version 1.7.13 or later.
47
+
48
+ - fastreg memory mode in krping needs a patch applied to krping.
49
+ Refer to the 'Testing' and 'Known Issues' sections for details.
50
+
51
+
52
+ Building and Installation
53
+ =========================
54
+
55
+ 1. Untar ixl-<version>.tar.gz and iw_ixl-<version>.tar.gz
56
+
57
+ # tar -xf ixl-<version>.tar.gz
58
+ # tar -xf iw_ixl-<version>.tar.gz
59
+
60
+ 2. Install the if_ixl driver:
61
+
62
+ # cd ixl-<version>/src directory
63
+ # make
64
+ # make install
65
+
66
+ 3. Install the iw_ixl driver:
67
+
68
+ # cd iw_ixl-<version>/src
69
+ # make clean
70
+ # make IXL_DIR=$PATH_TO_IXL/ixl-<version>/src
71
+ # make install
72
+
73
+ 4. Install the man page for the iw_ixl driver by copying the iw_ixl.4.gz file
74
+ to the directory where manual pages are held on your system. For instance:
75
+
76
+ # cp iw_ixl-<version>/doc/iw_ixl.4.gz /usr/share/man/man4/
77
+
78
+ For in-tree driver if_ixl-1.7.12-k or later, it is sufficient to follow
79
+ the instruction from point 3 but ensure the correct path to if_ixl source
80
+ folder is supplied. For instance:
81
+ IXL_DIR=/usr/src/sys/dev/ixl/
82
+
83
+
84
+ Testing
85
+ -------
86
+ 1. To load the iw_ixl driver, call:
87
+
88
+ # kldload iw_ixl
89
+
90
+ If if_ixl is not already loaded, the system will load it on its own.
91
+ Please remember to add
92
+ hw.ixl.enable_iwarp=1
93
+ to /boot/loader.conf file prior to if_ixl loading, to ensure the ixl
94
+ driver has the iwarp interface enabled.
95
+
96
+ 2. To validate the load of the driver, check:
97
+
98
+ # sysctl -a | grep infiniband
99
+
100
+ A number of sys.class.infiniband should appear, provided at least one
101
+ port of the X722 is up.
102
+
103
+ 3. The source code for krping software is provided with the kernel in
104
+ /usr/src/sys/contrib/rdma/krping/. To compile the software, change directory
105
+ to /usr/src/sys/modules/rdma/krping/ and invoke the following:
106
+
107
+ # make clean
108
+ # make
109
+ # make install
110
+
111
+ 4. Start krping server on one machine:
112
+
113
+ # echo size=64,count=1,port=6601,addr=100.0.0.189,server > /dev/krping
114
+ 5. Connect client from another machine:
115
+
116
+ # echo size=64,count=1,port=6601,addr=100.0.0.189,client > /dev/krping
117
+
118
+
119
+ Configuration
120
+ =============
121
+ The following sysctl options are visible:
122
+ - hw.iw_ixl.max_ceq
123
+ determines the maximum number of msix vectors available to the driver
124
+ for CEQ usage.
125
+ - hw.iw_ixl.debug
126
+ defines level of debug messages.
127
+ - hw.iw_ixl.mpa_version
128
+ shows the current MPA version used.
129
+
130
+ The max_ceq setting may be changed by adding:
131
+ hw.iw_ixl.max_ceq=$value
132
+ to /boot/loader.conf file. The final number of CEQ is evaluated depending
133
+ on the available msix vectors, number of cpu cores, and hardware limits.
134
+
135
+ If max_ceq=0, the value is ignored.
136
+
137
+ The debug setting may be changed either by adding:
138
+ hw.iw_ixl.debug=$value
139
+ to the /boot/loader.conf file or by calling
140
+ sysctl hw.iw_ixl.debug=$value
141
+
142
+ The mpa_version may be changed by adding:
143
+ hw.iw_ixl.mpa_version=$value
144
+ to the /boot/loader.conf file.
145
+
146
+
147
+ Interoperability
148
+ ================
149
+
150
+ To interoperate with Chelsio iWARP devices:
151
+
152
+ 1. Load the ixl driver with parameter mpa_version set to 1. Add the line:
153
+ hw.iw_ixl.mpa_version=1
154
+ to /boot/loader.conf
155
+
156
+ 2. Load Chelsio T4/T5 RDMA driver (iw_cxgb4) with parameter dack_mode set to 0.
157
+
158
+
159
+ Known Issues
160
+ ============
161
+
162
+ - Loopback is not supported.
163
+ - MTU changes are not supported.
164
+ - IPv6 is not supported.
165
+ - MW memory mode is not supported.
166
+ - MR memory mode supports only single buffer.
167
+ - The function ib_cq_resize is not supported.
168
+ - The max number of registered cq, qp, pd or mr reported by the device may
169
+ differ from the actual number of registrations achievable.
170
+ - A kernel crash may occur when trying to run krping without ensuring that the
171
+ two machines are able to ping each other.
172
+ - A kernel crash may occur when trying to load the iw_ixl driver when
173
+ hw.ixl.enable_iwarp=0 (fixed with if_ixl 1.7.13).
174
+ - A kernel crash may occur when loading the iw_ixl driver on a card that is
175
+ supported by if_ixl driver, but does not have iWARP capability (fixed with
176
+ if_ixl 1.7.13).
177
+ - Krping with fastreg memory mode will not work unless some changes are made
178
+ to krping. To work around the issue, modify the krping_rdma_rkey function
179
+ such that, in the case of FASTREG memory mode, the ib_post_send function
180
+ with &cd->invalidate_wr parameter is not called during the first run of
181
+ the function.
182
+
183
+
184
+ Support
185
+ =======
186
+ For general information, go to the Intel support website at:
187
+ http://www.intel.com/support/
188
+
189
+ If an issue is identified with the released source code on a supported kernel
190
+ with a supported adapter, email the specific information related to the issue
191
192
+
193
+
194
+ Copyright(c) 2017-2019 Intel Corporation.
195
+
196
+
197
+ Trademarks
198
+ ==========
199
+ Intel is a trademark or registered trademark of Intel Corporation or its
200
+ subsidiaries in the United States and/or other countries.
201
+
202
+ * Other names and brands may be claimed as the property of others.
203
+
204
+
same_training_and_inference_test.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Tests for when the training and inference graphs are the same."""
16
+
17
+ import os
18
+ import tempfile
19
+
20
+ import tensorflow as tf
21
+
22
+ from tensorflow_model_optimization.python.core.common.keras.compression.algorithms import same_training_and_inference as svd
23
+ from tensorflow_model_optimization.python.core.keras.compat import keras
24
+ from tensorflow_model_optimization.python.core.keras.testing import test_utils_mnist
25
+
26
+
27
+ def _build_model():
28
+ i = keras.layers.Input(shape=(28, 28), name='input')
29
+ x = keras.layers.Reshape((28, 28, 1))(i)
30
+ x = keras.layers.Conv2D(
31
+ 20, 5, activation='relu', padding='valid', name='conv1'
32
+ )(x)
33
+ x = keras.layers.MaxPool2D(2, 2)(x)
34
+ x = keras.layers.Conv2D(
35
+ 50, 5, activation='relu', padding='valid', name='conv2'
36
+ )(x)
37
+ x = keras.layers.MaxPool2D(2, 2)(x)
38
+ x = keras.layers.Flatten()(x)
39
+ x = keras.layers.Dense(500, activation='relu', name='fc1')(x)
40
+ output = keras.layers.Dense(10, name='fc2')(x)
41
+
42
+ model = keras.Model(inputs=[i], outputs=[output])
43
+ return model
44
+
45
+
46
+ def _get_dataset():
47
+ mnist = keras.datasets.mnist
48
+ (x_train, y_train), (x_test, y_test) = mnist.load_data()
49
+ x_train, x_test = x_train / 255.0, x_test / 255.0
50
+ # Use subset of 60000 examples to keep unit test speed fast.
51
+ x_train = x_train[0:1000]
52
+ y_train = y_train[0:1000]
53
+ return (x_train, y_train), (x_test, y_test)
54
+
55
+
56
+ def _train_model(model):
57
+ loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
58
+
59
+ model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
60
+
61
+ (x_train, y_train), _ = _get_dataset()
62
+
63
+ model.fit(x_train, y_train, epochs=1)
64
+
65
+
66
+ def _save_as_saved_model(model):
67
+ saved_model_dir = tempfile.mkdtemp()
68
+ model.save(saved_model_dir)
69
+ return saved_model_dir
70
+
71
+
72
+ # TODO(tfmot): reuse existing test utilities.
73
+ def _convert_to_tflite(saved_model_dir):
74
+ _, tflite_file = tempfile.mkstemp()
75
+
76
+ converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
77
+ tflite_model = converter.convert()
78
+
79
+ with open(tflite_file, 'wb') as f:
80
+ f.write(tflite_model)
81
+
82
+ return tflite_file
83
+
84
+
85
+ def _get_directory_size_in_bytes(directory):
86
+ total = 0
87
+ try:
88
+ for entry in os.scandir(directory):
89
+ if entry.is_file():
90
+ # if it's a file, use stat() function
91
+ total += entry.stat().st_size
92
+ elif entry.is_dir():
93
+ # if it's a directory, recursively call this function
94
+ total += _get_directory_size_in_bytes(entry.path)
95
+ except NotADirectoryError:
96
+ # if `directory` isn't a directory, get the file size then
97
+ return os.path.getsize(directory)
98
+ except PermissionError:
99
+ # if for whatever reason we can't open the folder, return 0
100
+ return 0
101
+ return total
102
+
103
+
104
+ class FunctionalTest(tf.test.TestCase):
105
+
106
+ # TODO(tfmot): can simplify to single layer test that checks exact
107
+ # dimensions of weights.
108
+ def testSVD_ReducesSavedModelSize(self):
109
+ model = _build_model()
110
+
111
+ original_saved_model_dir = _save_as_saved_model(model)
112
+
113
+ compressed_model = svd.SVD(rank=16).compress_model(model)
114
+
115
+ saved_model_dir = _save_as_saved_model(compressed_model)
116
+
117
+ original_size = _get_directory_size_in_bytes(original_saved_model_dir)
118
+ compressed_size = _get_directory_size_in_bytes(saved_model_dir)
119
+
120
+ self.assertLess(compressed_size, original_size / 3)
121
+
122
+ def testSVD_HasReasonableAccuracy_TF(self):
123
+ model = _build_model()
124
+
125
+ compressed_model = svd.SVD(rank=16).compress_model(model)
126
+
127
+ _train_model(compressed_model)
128
+
129
+ _, (x_test, y_test) = _get_dataset()
130
+
131
+ loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
132
+
133
+ compressed_model.compile(
134
+ optimizer='adam', loss=loss_fn, metrics=['accuracy'])
135
+
136
+ results = compressed_model.evaluate(x_test, y_test)
137
+
138
+ self.assertGreater(results[1], 0.60)
139
+
140
+ def testSVD_ReducesTFLiteModelSize(self):
141
+ model = _build_model()
142
+
143
+ original_saved_model_dir = _save_as_saved_model(model)
144
+ original_tflite_file = _convert_to_tflite(original_saved_model_dir)
145
+
146
+ compressed_model = svd.SVD(rank=16).compress_model(model)
147
+
148
+ saved_model_dir = _save_as_saved_model(compressed_model)
149
+ compressed_tflite_file = _convert_to_tflite(saved_model_dir)
150
+
151
+ original_size = os.path.getsize(original_tflite_file)
152
+ compressed_size = os.path.getsize(compressed_tflite_file)
153
+
154
+ self.assertLess(compressed_size, original_size / 6)
155
+
156
+ def testSVD_HasReasonableAccuracy_TFLite(self):
157
+ model = _build_model()
158
+
159
+ compressed_model = svd.SVD(rank=16).compress_model(model)
160
+
161
+ _train_model(compressed_model)
162
+
163
+ saved_model_dir = _save_as_saved_model(compressed_model)
164
+ compressed_tflite_file = _convert_to_tflite(saved_model_dir)
165
+
166
+ accuracy = test_utils_mnist.eval_tflite(compressed_tflite_file)
167
+
168
+ self.assertGreater(accuracy, 0.60)
169
+
170
+ # TODO(tfmot): can simplify to single layer test.
171
+ def testSVD_BreaksDownLayerWeights(self):
172
+ model = _build_model()
173
+
174
+ first_conv_layer = model.layers[2]
175
+ self.assertLen(first_conv_layer.weights, 2)
176
+
177
+ compressed_model = svd.SVD(rank=16).compress_model(model)
178
+
179
+ first_conv_layer = compressed_model.layers[2]
180
+
181
+ self.assertLen(first_conv_layer.weights, 3)
182
+
183
+ # TODO(tfmot): can simplify to single layer test.
184
+ def testSVD_PreservesPretrainedWeights(self):
185
+ i = keras.layers.Input(shape=(2), name='input')
186
+ output = keras.layers.Dense(3, name='fc1')(i)
187
+ model = keras.Model(inputs=[i], outputs=[output])
188
+
189
+ dense_layer_weights = model.layers[1].get_weights()
190
+
191
+ algorithm = svd.SVD(rank=1)
192
+ compressed_model = algorithm.compress_model(model)
193
+
194
+ dense_layer_compressed_weights = compressed_model.layers[1].get_weights()
195
+
196
+ # kernel
197
+ algorithm.weight_reprs = []
198
+ algorithm.init_training_weights(dense_layer_weights[0])
199
+ w1_repr, w2_repr = algorithm.weight_reprs
200
+ assert (w1_repr.kwargs['initializer'](None) == \
201
+ dense_layer_compressed_weights[0]).numpy().all()
202
+ assert (w2_repr.kwargs['initializer'](None) == \
203
+ dense_layer_compressed_weights[1]).numpy().all()
204
+
205
+ # bias
206
+ assert (dense_layer_weights[1] == dense_layer_compressed_weights[2]).all()
207
+
208
+
209
+ if __name__ == '__main__':
210
+ tf.test.main()