class horizon_plugin_pytorch.quantization.observer_v2.KLObserver(bins: int = 512, update_interval: int = 1, averaging_constant: float = 0.01, ch_axis: int = -1, dtype: dtype | QuantDType = 'qint8', qscheme: qscheme = torch.per_tensor_symmetric, quant_min: int = None, quant_max: int = None, is_sync_quantize: bool = False, factory_kwargs: Dict = None)
KL observer.
KL observer based on histogram. The histogram is computed online
and is not saved.
- Parameters:
- bins – Number of histogram bins.
- update_interval – Interval for computing KL entropy and updating min/max.
KLObserver constantly collects histograms of activations,
but only performs the KL computation when update_interval is reached.
If set to 1, KL entropy is computed at every forward step.
A larger interval costs less time and does no harm to calibration
accuracy; setting it to the total number of calibration steps achieves
the best performance. update_interval must be no greater than the
total number of calibration steps, otherwise no min/max will be computed.
- averaging_constant – Averaging constant for min/max.
- ch_axis – Channel axis.
- dtype – Quantized data type.
- qscheme – Quantization scheme to be used.
- quant_min – Min quantization value. Will follow dtype if unspecified.
- quant_max – Max quantization value. Will follow dtype if unspecified.
- is_sync_quantize – Whether to sync statistics when training with
multiple devices.
- factory_kwargs – kwargs which are passed to factory functions for
min_val and max_val.
forward(x_orig)
Defines the computation performed at every call.
Should be overridden by all subclasses.
NOTE
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
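Conceptually, the KL calibration searches over clipping thresholds of the histogram and keeps the one whose quantized distribution stays closest (in KL divergence) to the reference distribution. The following is a simplified pure-Python sketch of that idea; the function names and the bin-merging scheme are illustrative assumptions, not the plugin's actual implementation:

```python
import math

def kl_divergence(p, q):
    """KL(P || Q) over matching histogram bins; skips empty P bins."""
    total_p, total_q = sum(p), sum(q)
    div = 0.0
    for pi, qi in zip(p, q):
        if pi > 0:
            if qi == 0:
                return float("inf")
            div += (pi / total_p) * math.log((pi / total_p) / (qi / total_q))
    return div

def quantize_histogram(ref, num_quant_bins):
    """Merge `ref` into `num_quant_bins` groups, then expand back,
    spreading each group's mass uniformly over its non-empty bins."""
    n = len(ref)
    q = [0.0] * n
    group = n / num_quant_bins
    for g in range(num_quant_bins):
        start, end = int(g * group), int((g + 1) * group)
        chunk = ref[start:end]
        nonzero = [i for i, v in enumerate(chunk) if v > 0]
        if not nonzero:
            continue
        mass = sum(chunk) / len(nonzero)
        for i in nonzero:
            q[start + i] = mass
    return q

def kl_threshold(hist, num_quant_bins=128):
    """Return the bin index whose clipping threshold minimizes KL(P || Q)."""
    best_i, best_kl = len(hist), float("inf")
    for i in range(num_quant_bins, len(hist) + 1):
        ref = list(hist[:i])
        ref[-1] += sum(hist[i:])  # fold clipped outlier mass into the last bin
        q = quantize_histogram(hist[:i], num_quant_bins)
        kl = kl_divergence(ref, q)
        if kl < best_kl:
            best_i, best_kl = i, kl
    return best_i
```

Because only the histogram (not raw activations) is needed, the search can be deferred until `update_interval` is reached, which is why a large interval saves time without hurting accuracy.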
class horizon_plugin_pytorch.quantization.observer_v2.MSEObserver(stride: int = 1, averaging_constant: float = 0.01, ch_axis: int = -1, dtype: dtype | QuantDType = 'qint8', qscheme: qscheme = torch.per_tensor_symmetric, quant_min: int = None, quant_max: int = None, is_sync_quantize: bool = False, factory_kwargs: Dict = None)
MSE observer.
Observer module for computing the quantization parameters based on the
Mean Square Error (MSE) between the original tensor and the quantized one.
This observer performs a linear search for the quantization scale that
minimizes the MSE.
- Parameters:
- stride – Search stride. A larger value gives a smaller search space,
which means less computing time but possibly poorer accuracy.
Default is 1. Values no greater than 20 are suggested.
- averaging_constant – Averaging constant for min/max.
- ch_axis – Channel axis.
- dtype – Quantized data type.
- qscheme – Quantization scheme to be used.
- quant_min – Min quantization value. Will follow dtype if unspecified.
- quant_max – Max quantization value. Will follow dtype if unspecified.
- is_sync_quantize – Whether to sync statistics when training with
multiple devices.
- factory_kwargs – kwargs which are passed to factory functions for
min_val and max_val.
forward(x_orig)
Defines the computation performed at every call.
Should be overridden by all subclasses.
NOTE
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
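The linear search can be sketched in plain Python: step the candidate clipping range down from the observed absolute maximum, fake-quantize against each candidate, and keep the scale with the smallest MSE. All names below are illustrative, not the plugin's API, and the step count is an assumption:

```python
def fake_quantize(x, scale, quant_min=-128, quant_max=127):
    """Symmetric fake-quantization: quantize to an int, clamp, dequantize."""
    q = round(x / scale)
    q = max(quant_min, min(quant_max, q))
    return q * scale

def mse_search_scale(values, num_steps=100, stride=1, quant_max=127):
    """Linear-search the clipping range that minimizes quantization MSE."""
    abs_max = max(abs(v) for v in values)
    if abs_max == 0:
        return 1.0
    best_scale, best_mse = abs_max / quant_max, float("inf")
    for step in range(num_steps, 0, -stride):
        clip = abs_max * step / num_steps  # candidate clipping threshold
        scale = clip / quant_max
        mse = sum((v - fake_quantize(v, scale)) ** 2 for v in values) / len(values)
        if mse < best_mse:
            best_scale, best_mse = scale, mse
    return best_scale
```

A `stride` larger than 1 simply skips candidates, which is why it trades accuracy for speed.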
class horizon_plugin_pytorch.quantization.observer_v2.MinMaxObserver(averaging_constant: float = 0.01, ch_axis: int = -1, dtype: dtype | QuantDType = 'qint8', qscheme: qscheme = torch.per_tensor_symmetric, quant_min: int = None, quant_max: int = None, is_sync_quantize: bool = False, factory_kwargs: Dict = None)
Min max observer.
This observer computes the quantization parameters based on minimums and
maximums of the incoming tensors. The module records the moving average
minimum and maximum of incoming tensors, and uses this statistic to compute
the quantization parameters.
- Parameters:
- averaging_constant – Averaging constant for min/max.
- ch_axis – Channel axis.
- dtype – Quantized data type.
- qscheme – Quantization scheme to be used.
- quant_min – Min quantization value. Will follow dtype if unspecified.
- quant_max – Max quantization value. Will follow dtype if unspecified.
- is_sync_quantize – Whether to sync statistics when training with
multiple devices.
- factory_kwargs – kwargs which are passed to factory functions for
min_val and max_val.
forward(x_orig)
Record the running minimum and maximum of x.
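The moving-average update controlled by averaging_constant can be sketched as follows (a pure-Python illustration; the function name and the first-batch handling are assumptions, not the plugin's implementation):

```python
def update_min_max(min_val, max_val, x, averaging_constant=0.01):
    """Moving-average update of the running min/max with a new batch `x`."""
    batch_min, batch_max = min(x), max(x)
    if min_val is None:  # first batch: take its min/max directly
        return batch_min, batch_max
    min_val = min_val + averaging_constant * (batch_min - min_val)
    max_val = max_val + averaging_constant * (batch_max - max_val)
    return min_val, max_val
```

With the default averaging_constant of 0.01, each new batch nudges the recorded range by only 1% of the difference, so outlier batches are smoothed out.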
class horizon_plugin_pytorch.quantization.observer_v2.MixObserver(averaging_constant: float = 0.01, ch_axis: int = -1, dtype: dtype | QuantDType = 'qint8', qscheme: qscheme = torch.per_tensor_symmetric, quant_min: int = None, quant_max: int = None, is_sync_quantize: bool = False, factory_kwargs: Dict = None)
Mix observer.
This observer computes the quantization parameters based on multiple
calibration methods and selects the quantization parameters with the
smallest quantization error.
- Parameters:
- averaging_constant – Averaging constant for min/max.
- ch_axis – Channel axis.
- dtype – Quantized data type.
- qscheme – Quantization scheme to be used.
- quant_min – Min quantization value. Will follow dtype if unspecified.
- quant_max – Max quantization value. Will follow dtype if unspecified.
- is_sync_quantize – Whether to sync statistics when training with
multiple devices.
- factory_kwargs – kwargs which are passed to factory functions for
min_val and max_val.
forward(x_orig)
Defines the computation performed at every call.
Should be overridden by all subclasses.
NOTE
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
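The selection step can be illustrated in isolation: given candidate scales produced by different calibration methods, keep the one with the smallest quantization error. This is a pure-Python sketch under that reading of the description; the function names are hypothetical:

```python
def quant_error(values, scale, quant_max=127):
    """L2 error between `values` and their symmetric fake-quantized version."""
    err = 0.0
    for v in values:
        q = max(-quant_max - 1, min(quant_max, round(v / scale)))
        err += (v - q * scale) ** 2
    return err

def pick_best_scale(values, candidate_scales):
    """Select the candidate scale with the smallest quantization error."""
    return min(candidate_scales, key=lambda s: quant_error(values, s))
```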
class horizon_plugin_pytorch.quantization.observer_v2.PercentileObserver(percentile: float = 99.99, bins: int = 2048, averaging_constant: float = 0.01, ch_axis: int = -1, dtype: dtype | QuantDType = 'qint8', qscheme: qscheme = torch.per_tensor_symmetric, quant_min: int = None, quant_max: int = None, is_sync_quantize: bool = False, factory_kwargs: Dict = None)
Percentile observer.
Percentile observer based on histogram. The histogram is computed online
and is not saved. The minimum and maximum are moving-averaged to compute
the quantization parameters.
- Parameters:
- percentile – Index percentile of the histogram.
- bins – Number of histogram bins.
- averaging_constant – Averaging constant for min/max.
- ch_axis – Channel axis.
- dtype – Quantized data type.
- qscheme – Quantization scheme to be used.
- quant_min – Min quantization value. Will follow dtype if unspecified.
- quant_max – Max quantization value. Will follow dtype if unspecified.
- is_sync_quantize – Whether to sync statistics when training with
multiple devices.
- factory_kwargs – kwargs which are passed to factory functions for
min_val and max_val.
forward(x_orig)
Defines the computation performed at every call.
Should be overridden by all subclasses.
NOTE
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
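The percentile lookup itself amounts to walking the cumulative histogram until the requested fraction of the mass is covered. A minimal sketch, assuming a histogram of absolute values starting at 0 (the function name and bin-edge convention are illustrative):

```python
def percentile_threshold(hist, bin_width, percentile=99.99):
    """Return the value below which `percentile` percent of the
    histogram mass falls (histogram assumed to start at 0)."""
    target = sum(hist) * percentile / 100.0
    cumulative = 0.0
    for i, count in enumerate(hist):
        cumulative += count
        if cumulative >= target:
            return (i + 1) * bin_width  # right edge of the covering bin
    return len(hist) * bin_width
```

Clipping at the 99.99th percentile rather than the true maximum discards rare outliers, which usually tightens the quantization range.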
class horizon_plugin_pytorch.quantization.MovingAverageMinMaxObserver(averaging_constant=0.01, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, quant_min=None, quant_max=None, is_sync_quantize=False, factory_kwargs=None)
MovingAverageMinMax Observer.
Observer module for computing the quantization parameters based on the
moving average of the min and max values.
This observer computes the quantization parameters based on the moving
averages of minimums and maximums of the incoming tensors. The module
records the average minimum and maximum of incoming tensors, and uses this
statistic to compute the quantization parameters.
- Parameters:
- averaging_constant – Averaging constant for min/max.
- dtype – Quantized data type.
- qscheme – Quantization scheme to be used; only the
per_tensor_symmetric scheme is supported.
- reduce_range – Reduces the range of the quantized data type by 1 bit.
- quant_min – Minimum quantization value.
- quant_max – Maximum quantization value.
- is_sync_quantize – Whether to use sync quantize.
- factory_kwargs – Arguments for registering data buffers.
forward(x_orig)
Record the running minimum and maximum of x.
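Once the moving-average min/max is recorded, the per_tensor_symmetric scale follows from the larger of the two magnitudes. A minimal sketch of that derivation (the function name is illustrative, not the plugin's API):

```python
def symmetric_scale(min_val, max_val, quant_max=127):
    """Symmetric per-tensor scale: the larger magnitude of min/max
    mapped onto the positive quantized range."""
    bound = max(abs(min_val), abs(max_val))
    return bound / quant_max if bound > 0 else 1.0
```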
class horizon_plugin_pytorch.quantization.MovingAveragePerChannelMinMaxObserver(averaging_constant=0.01, ch_axis=0, dtype=torch.qint8, qscheme=torch.per_channel_symmetric, quant_min=None, quant_max=None, is_sync_quantize=False, factory_kwargs=None)
MovingAveragePerChannelMinMax Observer.
Observer module for computing the quantization parameters based on the
running per channel min and max values.
This observer uses the tensor min/max statistics to compute the per channel
quantization parameters. The module records the running minimum and maximum
of incoming tensors, and uses this statistic to compute the quantization
parameters.
- Parameters:
- averaging_constant – Averaging constant for min/max.
- ch_axis – Channel axis.
- dtype – Quantized data type.
- qscheme – Quantization scheme to be used; only
per_channel_symmetric is supported.
- quant_min – Minimum quantization value.
- quant_max – Maximum quantization value.
- is_sync_quantize – Whether to use sync quantize.
- factory_kwargs – Arguments for registering data buffers.
forward(x_orig)
Defines the computation performed at every call.
Should be overridden by all subclasses.
NOTE
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
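The difference from the per-tensor observer is the reduction: min/max are taken over every axis except ch_axis, yielding one range per channel. A pure-Python sketch for the common ch_axis=0 case on a 2-D input (names and the 2-D restriction are illustrative assumptions):

```python
def per_channel_min_max(tensor, ch_axis=0):
    """Per-channel min/max of a nested-list tensor of shape (C, N),
    reducing over every axis except `ch_axis` (here: axis 1)."""
    if ch_axis != 0:
        raise NotImplementedError("sketch handles ch_axis=0 on 2-D input only")
    mins = [min(row) for row in tensor]
    maxs = [max(row) for row in tensor]
    return mins, maxs
```

Each channel then gets its own scale, which helps when channel magnitudes differ widely (e.g. convolution weights).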