Training can be done with either PyTorch or Keras, but PyTorch is recommended because of the unreliable API of TensorFlow, the foundation of Keras.
Make sure you install all requirements:
```bash
pip install -r requirements.txt
```
Additionally, the Neural Network Extension package must be installed for verification (we don't trust TensorFlow/Keras/PyTorch blindly). Its source code is in the `python` folder at the repository root, and precompiled Python wheels are also available on GitHub.
To train and quantise the network, run the following in a terminal:

```bash
python train_torch.py
python quantize.py
```

Alternatively, use the Jupyter notebook to train the network.
Fixed-point quantisation is used.
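Fixed-point quantisation stores each real value as a signed integer scaled by a power of two. The sketch below illustrates the idea for a few weights, assuming a hypothetical 8-bit format with 7 fractional bits (Q0.7), round-to-nearest and saturation. The bit widths and rounding mode actually used by `quantize.py` are not stated here, so treat `frac_bits` and `total_bits` as illustrative parameters.

```python
def to_fixed(x: float, frac_bits: int = 7, total_bits: int = 8) -> int:
    """Quantise a real value to a signed fixed-point integer.

    Hypothetical format: `total_bits` bits in total, `frac_bits` of them fractional.
    Rounds to nearest (Python's round: ties to even) and saturates to the signed range.
    """
    q = round(x * (1 << frac_bits))
    lo = -(1 << (total_bits - 1))
    hi = (1 << (total_bits - 1)) - 1
    return max(lo, min(hi, q))


def from_fixed(q: int, frac_bits: int = 7) -> float:
    """Convert a fixed-point integer back to a real value."""
    return q / (1 << frac_bits)


weights = [0.5, -0.25, 0.3, 1.2]
codes = [to_fixed(w) for w in weights]
restored = [from_fixed(q) for q in codes]
print(codes)     # [64, -32, 38, 127]  (1.2 saturates at 127)
print(restored)  # [0.5, -0.25, 0.296875, 0.9921875]
```

Values that fall between grid points (0.3) pick up a rounding error, and values outside the representable range (1.2) are clipped to the largest code.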
| Network | Accuracy |
|---|---|
| Float | 0.9832 |
| Fake quant | 0.9832 |
| Quantised (fixed point) | 0.9349 |
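"Fake quant" generally means that values are quantised and immediately dequantised, so the computation still runs in floating point on values that lie on the fixed-point grid. A fully quantised network instead computes with the integer codes directly. The following minimal sketch contrasts the two for a single dot product. It assumes the same hypothetical Q0.7 format for weights and inputs, an integer accumulator, and requantisation by an arithmetic right shift. None of these details are taken from `quantize.py`.

```python
FRAC_BITS = 7     # hypothetical: Q0.7, 8-bit signed
LO, HI = -128, 127


def quantise(x: float) -> int:
    """Real value -> saturated fixed-point integer code."""
    return max(LO, min(HI, round(x * (1 << FRAC_BITS))))


def fake_quant(x: float) -> float:
    """Quantise then dequantise: still a float, but snapped to the fixed-point grid."""
    return quantise(x) / (1 << FRAC_BITS)


def dot_fake_quant(w: list[float], x: list[float]) -> float:
    """Dot product of fake-quantised operands, accumulated in floating point."""
    return sum(fake_quant(a) * fake_quant(b) for a, b in zip(w, x))


def dot_integer(w: list[float], x: list[float]) -> float:
    """Dot product computed on integer codes only."""
    # Each product of two Q0.7 codes carries 14 fractional bits.
    acc = sum(quantise(a) * quantise(b) for a, b in zip(w, x))
    # Requantise to Q0.7 with an arithmetic right shift (rounds toward -infinity).
    q = max(LO, min(HI, acc >> FRAC_BITS))
    return q / (1 << FRAC_BITS)


w = [0.5, -0.25, 0.3]
x = [0.6, 0.2, -0.4]
print(dot_fake_quant(w, x))  # 0.1317138671875
print(dot_integer(w, x))     # 0.125
```

In this sketch the two results differ because the integer path truncates the accumulator back to 8 bits, which the fake-quant path does not model. This only illustrates where the two paths can diverge. It is not a diagnosis of the accuracy difference in the table above.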