class horizon_plugin_pytorch.quantization.observer_v2.KLObserver(bins: int = 512, update_interval: int = 1, averaging_constant: float = 0.01, ch_axis: int = -1, dtype: dtype | QuantDType = 'qint8', qscheme: qscheme = torch.per_tensor_symmetric, quant_min: int = None, quant_max: int = None, is_sync_quantize: bool = False, factory_kwargs: Dict = None)

KL observer.

KL observer based on a histogram. The histogram is computed online and is not saved.

  • Parameters:
    • bins – Number of histogram bins.
    • update_interval – Interval (in forward steps) at which KL divergence is computed and min/max are updated. KLObserver continuously collects histograms of activations, but only performs the KL calculation every update_interval steps. If set to 1, KL divergence is computed at every forward step. A larger interval reduces calibration time without harming calibration accuracy; setting it to the total number of calibration steps gives the best performance. update_interval must be no greater than the total number of calibration steps, otherwise min/max will never be computed.
    • averaging_constant – Averaging constant for min/max.
    • ch_axis – Channel axis.
    • dtype – Quantized data type.
    • qscheme – Quantization scheme to be used.
    • quant_min – Min quantization value. Will follow dtype if unspecified.
    • quant_max – Max quantization value. Will follow dtype if unspecified.
    • is_sync_quantize – Whether to synchronize statistics when training with multiple devices.
    • factory_kwargs – kwargs which are passed to factory functions for min_val and max_val.
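The idea behind histogram-based KL calibration can be sketched in plain Python. This is a conceptual illustration of the general technique (clip the histogram at a candidate threshold, coarsen it to the quantized bin count, and keep the threshold whose coarsened distribution is closest in KL divergence to the original), not the plugin's internal implementation; the helper names, the 4-bin target, and the merge strategy are illustrative assumptions.

```python
import math

def kl_divergence(p, q):
    # KL(P || Q) over two discrete distributions given as lists summing to 1.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def best_threshold(hist, bin_width, target_bins=4):
    # Try every clipping point from target_bins to len(hist) bins and keep
    # the one whose quantized distribution best matches the original.
    best = (float("inf"), None)
    for i in range(target_bins, len(hist) + 1):
        # P: histogram clipped at bin i, with outliers folded into the last bin.
        p = list(hist[:i])
        p[-1] += sum(hist[i:])
        # Q: P merged down to target_bins coarse bins, then expanded back so
        # it is comparable bin-by-bin with P.
        group = i / target_bins
        q = [0.0] * i
        for j in range(target_bins):
            lo, hi = round(j * group), round((j + 1) * group)
            members = [k for k in range(lo, hi) if p[k] > 0]
            if members:
                avg = sum(p[k] for k in members) / len(members)
                for k in members:
                    q[k] = avg
        sp, sq = sum(p), sum(q)
        if sq == 0:
            continue
        d = kl_divergence([x / sp for x in p], [x / sq for x in q])
        if d < best[0]:
            best = (d, i * bin_width)
    return best[1]
```

The returned threshold would then determine the quantization scale; in the real observer the histogram is accumulated online across forward steps and this search only runs every update_interval steps.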

forward(x_orig)

Defines the computation performed at every call.

Should be overridden by all subclasses.

NOTE

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class horizon_plugin_pytorch.quantization.observer_v2.MSEObserver(stride: int = 1, averaging_constant: float = 0.01, ch_axis: int = -1, dtype: dtype | QuantDType = 'qint8', qscheme: qscheme = torch.per_tensor_symmetric, quant_min: int = None, quant_max: int = None, is_sync_quantize: bool = False, factory_kwargs: Dict = None)

MSE observer.

Observer module for computing the quantization parameters based on the Mean Square Error (MSE) between the original tensor and the quantized one.

This observer linear searches the quantization scales that minimize MSE.

  • Parameters:
    • stride – Search stride. A larger value gives a smaller search space, which means less computing time but possibly poorer accuracy. Default is 1. A value no greater than 20 is suggested.
    • averaging_constant – Averaging constant for min/max.
    • ch_axis – Channel axis.
    • dtype – Quantized data type.
    • qscheme – Quantization scheme to be used.
    • quant_min – Min quantization value. Will follow dtype if unspecified.
    • quant_max – Max quantization value. Will follow dtype if unspecified.
    • is_sync_quantize – Whether to synchronize statistics when training with multiple devices.
    • factory_kwargs – kwargs which are passed to factory functions for min_val and max_val.
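The linear search described above can be sketched in plain Python (a conceptual sketch, not the plugin's implementation; the 100 candidate clipping points, the qint8 range, and the helper names are illustrative assumptions). Note how `stride` skips candidates, shrinking the search space exactly as the parameter description says:

```python
def fake_quantize(x, scale, qmin=-128, qmax=127):
    # Symmetric fake quantization: quantize, clamp, then dequantize.
    return [max(qmin, min(qmax, round(v / scale))) * scale for v in x]

def mse_search_scale(x, steps=100, stride=1):
    # Linearly search candidate clipping ranges and return the scale
    # that minimizes the MSE between x and its fake-quantized version.
    absmax = max(abs(v) for v in x)
    best_mse, best_scale = float("inf"), absmax / 127
    for i in range(1, steps + 1, stride):
        clip = absmax * i / steps          # candidate clipping range
        scale = clip / 127                 # symmetric qint8 scale
        xq = fake_quantize(x, scale)
        mse = sum((a - b) ** 2 for a, b in zip(x, xq)) / len(x)
        if mse < best_mse:
            best_mse, best_scale = mse, scale
    return best_scale
```

With an outlier-heavy input, the search tends to pick a clipping range below the absolute maximum, trading a little clipping error for much finer resolution on the bulk of the values.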

forward(x_orig)

Defines the computation performed at every call.

Should be overridden by all subclasses.

NOTE

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class horizon_plugin_pytorch.quantization.observer_v2.MinMaxObserver(averaging_constant: float = 0.01, ch_axis: int = -1, dtype: dtype | QuantDType = 'qint8', qscheme: qscheme = torch.per_tensor_symmetric, quant_min: int = None, quant_max: int = None, is_sync_quantize: bool = False, factory_kwargs: Dict = None)

Min-max observer.

This observer computes the quantization parameters based on minimums and maximums of the incoming tensors. The module records the moving average minimum and maximum of incoming tensors, and uses this statistic to compute the quantization parameters.

  • Parameters:
    • averaging_constant – Averaging constant for min/max.
    • ch_axis – Channel axis.
    • dtype – Quantized data type.
    • qscheme – Quantization scheme to be used.
    • quant_min – Min quantization value. Will follow dtype if unspecified.
    • quant_max – Max quantization value. Will follow dtype if unspecified.
    • is_sync_quantize – Whether to synchronize statistics when training with multiple devices.
    • factory_kwargs – kwargs which are passed to factory functions for min_val and max_val.
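The moving-average update described above can be sketched as follows (a minimal illustration of the statistic, assuming symmetric qint8; the function names are illustrative, not the observer's API):

```python
def update_min_max(min_val, max_val, batch, averaging_constant=0.01):
    # Moving-average update of the running min/max. On the first call
    # (no running statistics yet) the batch statistics are taken directly.
    b_min, b_max = min(batch), max(batch)
    if min_val is None:
        return b_min, b_max
    c = averaging_constant
    return min_val + c * (b_min - min_val), max_val + c * (b_max - max_val)

def symmetric_scale(min_val, max_val, qmax=127):
    # Quantization scale for a symmetric scheme: the larger absolute
    # bound divided by the positive quantized range.
    return max(abs(min_val), abs(max_val)) / qmax
```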

forward(x_orig)

Record the running minimum and maximum of x.

class horizon_plugin_pytorch.quantization.observer_v2.MixObserver(averaging_constant: float = 0.01, ch_axis: int = -1, dtype: dtype | QuantDType = 'qint8', qscheme: qscheme = torch.per_tensor_symmetric, quant_min: int = None, quant_max: int = None, is_sync_quantize: bool = False, factory_kwargs: Dict = None)

Mix observer.

This observer computes the quantization parameters based on multiple calibration methods and selects the quantization parameters with the smallest quantization error.

  • Parameters:
    • averaging_constant – Averaging constant for min/max.
    • ch_axis – Channel axis.
    • dtype – Quantized data type.
    • qscheme – Quantization scheme to be used.
    • quant_min – Min quantization value. Will follow dtype if unspecified.
    • quant_max – Max quantization value. Will follow dtype if unspecified.
    • is_sync_quantize – Whether to synchronize statistics when training with multiple devices.
    • factory_kwargs – kwargs which are passed to factory functions for min_val and max_val.
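The selection step can be sketched as follows: given candidate scales produced by different calibration methods, keep the one that introduces the least quantization error. This is a conceptual sketch assuming symmetric qint8 and MSE as the error metric; the actual set of calibration methods and error measure used by MixObserver are internal.

```python
def quant_error(x, scale, qmin=-128, qmax=127):
    # Mean squared error introduced by symmetric fake quantization.
    err = 0.0
    for v in x:
        q = max(qmin, min(qmax, round(v / scale)))
        err += (q * scale - v) ** 2
    return err / len(x)

def mix_select(x, candidate_scales):
    # Among candidate scales (e.g. from min/max, percentile, and MSE
    # calibration), return the one with the smallest quantization error.
    return min(candidate_scales, key=lambda s: quant_error(x, s))
```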

forward(x_orig)

Defines the computation performed at every call.

Should be overridden by all subclasses.

NOTE

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class horizon_plugin_pytorch.quantization.observer_v2.PercentileObserver(percentile: float = 99.99, bins: int = 2048, averaging_constant: float = 0.01, ch_axis: int = -1, dtype: dtype | QuantDType = 'qint8', qscheme: qscheme = torch.per_tensor_symmetric, quant_min: int = None, quant_max: int = None, is_sync_quantize: bool = False, factory_kwargs: Dict = None)

Percentile observer.

Percentile observer based on a histogram. The histogram is computed online and is not saved. The minimum and maximum are moving-averaged to compute the quantization parameters.

  • Parameters:
    • percentile – Percentile of the histogram used as the min/max boundary.
    • bins – Number of histogram bins.
    • averaging_constant – Averaging constant for min/max.
    • ch_axis – Channel axis.
    • dtype – Quantized data type.
    • qscheme – Quantization scheme to be used.
    • quant_min – Min quantization value. Will follow dtype if unspecified.
    • quant_max – Max quantization value. Will follow dtype if unspecified.
    • is_sync_quantize – Whether to synchronize statistics when training with multiple devices.
    • factory_kwargs – kwargs which are passed to factory functions for min_val and max_val.
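Reading a percentile off an online histogram can be sketched like this (a minimal illustration; the function name and edge handling are assumptions, not the observer's API):

```python
def percentile_from_hist(hist, bin_edges, percentile=99.99):
    # Walk the histogram until the cumulative count reaches the target
    # fraction, and return the upper edge of that bin.
    total = sum(hist)
    target = total * percentile / 100.0
    acc = 0.0
    for count, edge in zip(hist, bin_edges[1:]):
        acc += count
        if acc >= target:
            return edge
    return bin_edges[-1]
```

In the real observer this boundary would feed the same moving-average min/max update as the other observers, with percentile=99.99 clipping away the most extreme 0.01% of observed activations.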

forward(x_orig)

Defines the computation performed at every call.

Should be overridden by all subclasses.

NOTE

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class horizon_plugin_pytorch.quantization.MovingAverageMinMaxObserver(averaging_constant=0.01, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, quant_min=None, quant_max=None, is_sync_quantize=False, factory_kwargs=None)

MovingAverageMinMax Observer.

Observer module for computing the quantization parameters based on the moving average of the min and max values.

This observer computes the quantization parameters based on the moving averages of minimums and maximums of the incoming tensors. The module records the average minimum and maximum of incoming tensors, and uses this statistic to compute the quantization parameters.

  • Parameters:
    • averaging_constant – Averaging constant for min/max.
    • dtype – Quantized data type
    • qscheme – Quantization scheme to be used; only the per_tensor_symmetric scheme is supported.
    • reduce_range – Reduces the range of the quantized data type by 1 bit.
    • quant_min – Minimum quantization value.
    • quant_max – Maximum quantization value.
    • is_sync_quantize – Whether to synchronize statistics when training with multiple devices.
    • factory_kwargs – Arguments for registering data buffers.

forward(x_orig)

Record the running minimum and maximum of x.

class horizon_plugin_pytorch.quantization.MovingAveragePerChannelMinMaxObserver(averaging_constant=0.01, ch_axis=0, dtype=torch.qint8, qscheme=torch.per_channel_symmetric, quant_min=None, quant_max=None, is_sync_quantize=False, factory_kwargs=None)

MovingAveragePerChannelMinMax Observer.

Observer module for computing the quantization parameters based on the running per channel min and max values.

This observer uses the tensor min/max statistics to compute the per channel quantization parameters. The module records the running minimum and maximum of incoming tensors, and uses this statistic to compute the quantization parameters.

  • Parameters:
    • averaging_constant – Averaging constant for min/max.
    • ch_axis – Channel axis
    • dtype – Quantized data type
    • qscheme – Quantization scheme to be used; only per_channel_symmetric is supported.
    • quant_min – Minimum quantization value.
    • quant_max – Maximum quantization value.
    • is_sync_quantize – Whether to synchronize statistics when training with multiple devices.
    • factory_kwargs – Arguments for registering data buffers.
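The per-channel statistics described above can be sketched in plain Python for a 2-D tensor given as nested lists (an illustration only; the function names and the list-of-lists representation are assumptions, the real observer operates on torch tensors):

```python
def per_channel_min_max(x, ch_axis=0):
    # Per-channel min/max of a 2-D nested list, channels along ch_axis.
    if ch_axis == 1:
        x = list(map(list, zip(*x)))  # move the channel axis to the front
    mins = [min(row) for row in x]
    maxs = [max(row) for row in x]
    return mins, maxs

def update_per_channel(running, batch, averaging_constant=0.01):
    # Moving-average update applied independently to each channel.
    c = averaging_constant
    return [r + c * (b - r) for r, b in zip(running, batch)]
```

Each channel then gets its own symmetric scale from its own min/max, instead of one scale shared by the whole tensor.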

forward(x_orig)

Defines the computation performed at every call.

Should be overridden by all subclasses.

NOTE

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
