from __future__ import absolute_import, division, print_function, unicode_literals
import math
import warnings
from abc import ABCMeta, abstractmethod
from functools import partial
import torch
import torch.nn as nn
from torch._jit_internal import List, Optional
def _with_args(cls_or_self, **kwargs):
r"""Wrapper that allows creation of class factories.
This can be useful when there is a need to create classes with the same
constructor arguments, but different instances.
.. Example::
>>> Foo.with_args = classmethod(_with_args)
>>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42)
>>> foo_instance1 = foo_builder()
>>> foo_instance2 = foo_builder()
>>> id(foo_instance1) == id(foo_instance2)
False
"""
class _PartialWrapper(object):
def __init__(self, p):
self.p = p
def __call__(self, *args, **keywords):
return self.p(*args, **keywords)
def __repr__(self):
return self.p.__repr__()
with_args = _with_args
r = _PartialWrapper(partial(cls_or_self, **kwargs))
return r
ABC = ABCMeta(str("ABC"), (object,), {}) # compatible with Python 2 *and* 3:
class ObserverBase(ABC, nn.Module):
r"""Base observer Module.
Any observer implementation should derive from this class.
Concrete observers should follow the same API. In forward, they will update
the statistics of the observed Tensor. And they should provide a
`calculate_qparams` function that computes the quantization parameters given
the collected statistics.
Args:
dtype: Quantized data type
"""
def __init__(self, dtype):
super(ObserverBase, self).__init__()
self.dtype = dtype
@abstractmethod
def forward(self, x):
pass
@abstractmethod
def calculate_qparams(self, **kwargs):
pass
with_args = classmethod(_with_args)
class _ObserverBase(ObserverBase):
r"""Internal common base for all qint/quint8 observers.
This base is for commonly used paramters used internally.
Users should use `~torch.quantization.observer.ObserverBase` as a base class
for custom observers.
Args:
dtype: Quantized data type.
qscheme: Quantization scheme to be used.
reduce_range: Reduces the range of the quantized data type by 1 bit.
This is sometimes required to avoid instruction overflow.
.. warning::
:attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``.
.. warning::
:attr:`qscheme` can only take one of the following options:
- ``torch.per_tensor_affine``
- ``torch.per_tensor_symmetric``
- ``torch.per_channel_affine``
- ``torch.per_channel_symmetric``
"""
def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
reduce_range=False):
super(_ObserverBase, self).__init__(dtype=dtype)
self.qscheme = qscheme
self.reduce_range = reduce_range
self.eps = torch.finfo(torch.float32).eps
assert self.qscheme in (
torch.per_tensor_affine,
torch.per_tensor_symmetric,
torch.per_channel_affine,
torch.per_channel_symmetric,
), "Default Observer only works for per_tensor_affine, \
per_tensor_symmetric, per_channel_affine and \
per_channel_symmetric quantization scheme"
assert self.dtype in (
torch.qint8,
torch.quint8,
), "Default Observer only works for qint8 and quint8 data type"
def _calculate_per_channel_qparams(self, min_vals, max_vals):
# type: (Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor, int]
r"""Calculates the per channel quantization parameters, given min and max
value tensors.
Args:
min_vals: Minimum values per channel
max_vals: Maximum values per channel
Returns:
scales: Per channel scales tensor of shape (#channels,)
zero_points: Per channel zero points tensor of shape (#channels,)
"""
if min_vals is None or max_vals is None:
warnings.warn(
"must run observer before calling calculate_qparams.\
Returning default scale and zero point "
)
return torch.tensor([1.0]), torch.tensor([0]), self.ch_axis
for i in range(len(min_vals)):
assert (
min_vals[i] <= max_vals[i]
), "min {} should be less than max {}".format(min_vals[i], max_vals[i])
scales = torch.empty(min_vals.size(), dtype=torch.float32)
zero_points = torch.empty(min_vals.size(), dtype=torch.int64)
for i in range(len(scales)):
qparam = self._calculate_qparams(
min_vals[i], max_vals[i]
)
scales[i] = float(qparam[0])
zero_points[i] = int(qparam[1])
return scales, zero_points, self.ch_axis
@torch.jit.export
def _calculate_qparams(self, min_val, max_val):
# type: (Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
r"""Calculates the per tensor quantization parameters, given the min/max.
Args:
min_val: Per tensor minimum value
max_val: Per tensor maximum value
Returns:
scale: Scale as a tensor of shape (1,)
zero_point: Zero point as a tensor of shape (1,)
"""
if max_val is None or min_val is None:
warnings.warn("Must run observer before calling calculate_qparams.\
Returning default scale and zero point.")
return torch.tensor([1.0]), torch.tensor([0])
assert min_val <= max_val, "min {} should be less than max {}".format(
min_val, max_val
)
if self.dtype == torch.qint8:
if self.reduce_range:
qmin, qmax = -64, 63
else:
qmin, qmax = -128, 127
else:
if self.reduce_range:
qmin, qmax = 0, 127
else:
qmin, qmax = 0, 255
max_val, min_val = float(max_val), float(min_val)
min_val = min(0.0, min_val)
max_val = max(0.0, max_val)
if max_val == min_val:
scale = 1.0
zero_point = 0
else:
if self.qscheme == torch.per_tensor_symmetric or self.qscheme == torch.per_channel_symmetric:
max_val = max(-min_val, max_val)
scale = max_val / ((qmax - qmin) / 2)
scale = max(scale, self.eps)
zero_point = 0 if self.dtype == torch.qint8 else 128
else:
scale = (max_val - min_val) / float(qmax - qmin)
scale = max(scale, self.eps)
zero_point = qmin - round(min_val / scale)
zero_point = max(qmin, zero_point)
zero_point = min(qmax, zero_point)
zero_point = int(zero_point)
return torch.tensor([scale]), torch.tensor([zero_point])
[docs]class MinMaxObserver(_ObserverBase):
r"""Observer module for computing the quantization parameters based on the
running min and max values.
This observer uses the tensor min/max statistics to compute the quantization
parameters. The module records the running minimum and maximum of incoming
tensors, and uses this statistic to compute the quantization parameters.
Args:
dtype: Quantized data type
qscheme: Quantization scheme to be used
reduce_range: Reduces the range of the quantized data type by 1 bit
Given running min/max as :math:`x_\text{min}` and :math:`x_\text{max}`,
scale :math:`s` and zero point :math:`z` are computed as:
The running minimum/maximum :math:`x_\text{min/max}` is computed as:
.. math::
\begin{array}{ll}
x_\text{min} &= \begin{cases}
\min(X) & \text{if~}x_\text{min} = \text{None} \\
\min\left(x_\text{min}, \min(X)\right) & \text{otherwise}
\end{cases}\\
x_\text{max} &= \begin{cases}
\max(X) & \text{if~}x_\text{max} = \text{None} \\
\max\left(x_\text{max}, \max(X)\right) & \text{otherwise}
\end{cases}\\
\end{array}
where :math:`X` is the observed tensor.
The scale :math:`s` and zero point :math:`z` are then computed as:
.. math::
\begin{aligned}
\text{if Symmetric:}&\\
&s = 2 \max(|x_\text{min}|, x_\text{max}) /
\left( Q_\text{max} - Q_\text{min} \right) \\
&z = \begin{cases}
0 & \text{if dtype is qint8} \\
128 & \text{otherwise}
\end{cases}\\
\text{Otherwise:}&\\
&s = \left( x_\text{max} - x_\text{min} \right ) /
\left( Q_\text{max} - Q_\text{min} \right ) \\
&z = Q_\text{min} - \text{round}(x_\text{min} / s)
\end{aligned}
where :math:`Q_\text{min}` and :math:`Q_\text{max}` are the minimum and
maximum of the quantized data type.
.. warning:: Only works with ``torch.per_tensor_symmetric`` quantization scheme
.. warning:: :attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``.
.. note:: If the running minimum equals to the running maximum, the scale
and zero_point are set to 1.0 and 0.
"""
__annotations__ = {
"min_val": Optional[torch.Tensor],
"max_val": Optional[torch.Tensor],
}
def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
reduce_range=False):
# For x86 quantized kernels, we need to ensure that the vpmaddubsw
# instruction does not overflow. We allow for a reduce_range argument to
# observers that reduces the quantized range to (0,127) or (-64, 63).
# For more details see aten/src/ATen/native/quantized/cpu/qconv.cpp
# This is not an optimal choice for non x86 backends as it loses a bit
# of precision for activations.
super(MinMaxObserver, self).__init__(dtype=dtype,
qscheme=qscheme,
reduce_range=reduce_range)
self.min_val = None
self.max_val = None
if self.qscheme == torch.per_tensor_symmetric and \
self.reduce_range and \
self.dtype == torch.quint8:
raise NotImplementedError("Cannot reduce range for symmetric \
quantization for quint8")
def forward(self, x_orig):
r"""Records the running minimum and maximum of ``x``."""
x = x_orig.detach() # avoid keeping autograd tape
min_val = self.min_val
max_val = self.max_val
if min_val is None or max_val is None:
min_val = torch.min(x)
max_val = torch.max(x)
else:
min_val = torch.min(torch.min(x), min_val)
max_val = torch.max(torch.max(x), max_val)
self.min_val = min_val
self.max_val = max_val
return x_orig
@torch.jit.export
def calculate_qparams(self):
r"""Calculates the quantization parameters."""
return self._calculate_qparams(self.min_val, self.max_val)
@torch.jit.export
def extra_repr(self):
return "min_val={}, max_val={}".format(self.min_val, self.max_val)
def _save_to_state_dict(self, destination, prefix, keep_vars):
super(MinMaxObserver, self)._save_to_state_dict(destination, prefix, keep_vars)
destination[prefix + 'min_val'] = self.min_val
destination[prefix + 'max_val'] = self.max_val
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
self.min_val = state_dict.pop(prefix + 'min_val')
self.max_val = state_dict.pop(prefix + 'max_val')
super(MinMaxObserver, self)._load_from_state_dict(state_dict, prefix, local_metadata, False,
missing_keys, unexpected_keys, error_msgs)
[docs]class MovingAverageMinMaxObserver(MinMaxObserver):
r"""Observer module for computing the quantization parameters based on the
moving average of the min and max values.
This observer computes the quantization parameters based on the moving
averages of minimums and maximums of the incoming tensors. The module
records the average minimum and maximum of incoming tensors, and uses this
statistic to compute the quantization parameters.
Args:
averaging_constant: Averaging constant for min/max.
dtype: Quantized data type
qscheme: Quantization scheme to be used
reduce_range: Reduces the range of the quantized data type by 1 bit
The moving average min/max is computed as follows
.. math::
\begin{array}{ll}
x_\text{min} = \begin{cases}
\min(X) & \text{if~}x_\text{min} = \text{None} \\
(1 - c) x_\text{min} + c \min(X) & \text{otherwise}
\end{cases}\\
x_\text{max} = \begin{cases}
\max(X) & \text{if~}x_\text{max} = \text{None} \\
(1 - c) x_\text{max} + c \max(X) & \text{otherwise}
\end{cases}\\
\end{array}
where :math:`x_\text{min/max}` is the running average min/max, :math:`X` is
is the incoming tensor, and :math:`c` is the ``averaging_constant``.
The scale and zero point are then computed as in
:class:`~torch.quantization.observer.MinMaxObserver`.
.. note:: Only works with ``torch.per_tensor_affine`` quantization shceme.
.. note:: If the running minimum equals to the running maximum, the scale
and zero_point are set to 1.0 and 0.
"""
def __init__(self, averaging_constant=0.01, dtype=torch.quint8,
qscheme=torch.per_tensor_affine, reduce_range=False):
self.averaging_constant = averaging_constant
super(MovingAverageMinMaxObserver, self).__init__(dtype=dtype,
qscheme=qscheme,
reduce_range=reduce_range)
def forward(self, x_orig):
x = x_orig.detach() # avoid keeping autograd tape
min_val = self.min_val
max_val = self.max_val
if min_val is None or max_val is None:
min_val = torch.min(x)
max_val = torch.max(x)
else:
min_val = min_val + self.averaging_constant * (torch.min(x) - min_val)
max_val = max_val + self.averaging_constant * (torch.max(x) - max_val)
self.min_val = min_val
self.max_val = max_val
return x_orig
[docs]class PerChannelMinMaxObserver(_ObserverBase):
r"""Observer module for computing the quantization parameters based on the
running per channel min and max values.
This observer uses the tensor min/max statistics to compute the per channel
quantization parameters. The module records the running minimum and maximum
of incoming tensors, and uses this statistic to compute the quantization
parameters.
Args:
ch_axis: Channel axis
dtype: Quantized data type
qscheme: Quantization scheme to be used
reduce_range: Reduces the range of the quantized data type by 1 bit
The quantization parameters are computed the same way as in
:class:`~torch.quantization.observer.MinMaxObserver`, with the difference
that the running min/max values are stored per channel.
Scales and zero points are thus computed per channel as well.
.. note:: If the running minimum equals to the running maximum, the scales
and zero_points are set to 1.0 and 0.
"""
__annotations__ = {
"min_vals": Optional[torch.Tensor],
"max_vals": Optional[torch.Tensor],
}
def __init__(self, ch_axis=0, dtype=torch.quint8,
qscheme=torch.per_channel_affine, reduce_range=False):
super(PerChannelMinMaxObserver, self).__init__(dtype=dtype,
qscheme=qscheme,
reduce_range=reduce_range)
self.ch_axis = ch_axis
self.min_vals = None
self.max_vals = None
if (
self.qscheme == torch.per_channel_symmetric
and self.reduce_range
and self.dtype == torch.quint8
):
raise NotImplementedError(
"Cannot reduce range for symmetric quantization for quint8"
)
def forward(self, x_orig):
return self._forward(x_orig)
@torch.jit.ignore
def _forward(self, x_orig):
x = x_orig.detach() # avoid keeping autograd tape
min_vals = self.min_vals
max_vals = self.max_vals
x_dim = x.size()
new_axis_list = list(range(len(x_dim)))
new_axis_list[self.ch_axis] = 0
new_axis_list[0] = self.ch_axis
y = x.permute(tuple(new_axis_list))
y = torch.flatten(y, start_dim=1)
if min_vals is None or max_vals is None:
min_vals = torch.min(y, 1)[0]
max_vals = torch.max(y, 1)[0]
else:
min_vals = torch.min(torch.min(y, 1)[0], min_vals)
max_vals = torch.max(torch.max(y, 1)[0], max_vals)
self.min_vals = min_vals
self.max_vals = max_vals
return x_orig
@torch.jit.export
def calculate_qparams(self):
return self._calculate_per_channel_qparams(self.min_vals, self.max_vals)
def extra_repr(self):
return "min_val={}, max_val={}".format(self.min_vals, self.max_vals)
def _save_to_state_dict(self, destination, prefix, keep_vars):
super(PerChannelMinMaxObserver, self)._save_to_state_dict(destination, prefix, keep_vars)
destination[prefix + 'min_vals'] = self.min_vals
destination[prefix + 'max_vals'] = self.max_vals
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
self.min_vals = state_dict.pop(prefix + 'min_vals')
self.max_vals = state_dict.pop(prefix + 'max_vals')
super(PerChannelMinMaxObserver, self)._load_from_state_dict(state_dict, prefix, local_metadata, False,
missing_keys, unexpected_keys, error_msgs)
[docs]class MovingAveragePerChannelMinMaxObserver(PerChannelMinMaxObserver):
r"""Observer module for computing the quantization parameters based on the
running per channel min and max values.
This observer uses the tensor min/max statistics to compute the per channel
quantization parameters. The module records the running minimum and maximum
of incoming tensors, and uses this statistic to compute the quantization
parameters.
Args:
averaging_constant: Averaging constant for min/max.
ch_axis: Channel axis
dtype: Quantized data type
qscheme: Quantization scheme to be used
reduce_range: Reduces the range of the quantized data type by 1 bit
The quantization parameters are computed the same way as in
:class:`~torch.quantization.observer.MovingAverageMinMaxObserver`, with the
difference that the running min/max values are stored per channel.
Scales and zero points are thus computed per channel as well.
.. note:: If the running minimum equals to the running maximum, the scales
and zero_points are set to 1.0 and 0.
"""
def __init__(self, averaging_constant=0.01, ch_axis=0, dtype=torch.quint8,
qscheme=torch.per_channel_affine, reduce_range=False):
super(MovingAveragePerChannelMinMaxObserver, self).__init__(
ch_axis=ch_axis, dtype=dtype, qscheme=qscheme,
reduce_range=reduce_range)
self.averaging_constant = averaging_constant
def forward(self, x_orig):
x = x_orig.detach() # avoid keeping autograd tape
min_vals = self.min_vals
max_vals = self.max_vals
x_dim = x.size()
new_axis_list = list(range(len(x_dim)))
new_axis_list[self.ch_axis] = 0
new_axis_list[0] = self.ch_axis
y = x.permute(tuple(new_axis_list))
y = torch.flatten(y, start_dim=1)
if min_vals is None or max_vals is None:
min_vals = torch.min(y, 1)[0]
max_vals = torch.max(y, 1)[0]
else:
min_vals = min_vals + self.averaging_constant * (torch.min(y, 1)[0] - min_vals)
max_vals = max_vals + self.averaging_constant * (torch.max(y, 1)[0] - max_vals)
self.min_vals = min_vals
self.max_vals = max_vals
return x_orig
[docs]class HistogramObserver(_ObserverBase):
r"""
The module records the running histogram of tensor values along with
min/max values. ``calculate_qparams`` will calculate scale and zero_point.
Args:
bins: Number of bins to use for the histogram
upsample_rate: Factor by which the histograms are upsampled, this is
used to interpolate histograms with varying ranges across observations
dtype: Quantized data type
qscheme: Quantization scheme to be used
reduce_range: Reduces the range of the quantized data type by 1 bit
The scale and zero point are computed as follows:
1. Create the histogram of the incoming inputs.
The histogram is computed continuously, and the ranges per bin change
with every new tensor observed.
2. Search the distribution in the histogram for optimal min/max values.
The search for the min/max values ensures the minimization of the
quantization error with respect to the floating point model.
3. Compute the scale and zero point the same way as in the
:class:`~torch.quantization.MinMaxObserver`
"""
__annotations__ = {
"min_val": Optional[torch.Tensor],
"max_val": Optional[torch.Tensor],
}
def __init__(self, bins=2048, upsample_rate=128, dtype=torch.quint8,
qscheme=torch.per_tensor_affine, reduce_range=False):
# bins: The number of bins used for histogram calculation.
super(HistogramObserver, self).__init__(dtype=dtype,
qscheme=qscheme,
reduce_range=reduce_range)
self.bins = bins
self.register_buffer('histogram', torch.zeros(self.bins))
self.min_val = None
self.max_val = None
self.dst_nbins = 2 ** torch.iinfo(self.dtype).bits
self.upsample_rate = upsample_rate
@torch.jit.ignore
def _non_linear_param_search(self):
r"""Non-linear parameter search.
An approximation for L2 error minimization for selecting min/max.
By selecting new min/max, we filter out outliers in input distribution.
This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in
caffe2/quantization/server/norm_minimization.cc
"""
def _get_norm(delta_begin, delta_end, density, norm_type):
r"""
Compute the norm of the values uniformaly distributed between
delta_begin and delta_end.
norm = density * (integral_{begin, end} x^2)
= density * (end^3 - begin^3) / 3
"""
assert norm_type == "L2", "Only L2 norms are currently supported"
norm = 0.0
if norm_type == "L2":
norm = (
delta_end * delta_end * delta_end
- delta_begin * delta_begin * delta_begin
) / 3
return density * norm
def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
r"""
Compute the quantization error if we use start_bin to end_bin as the
min and max to do the quantization.
"""
bin_width = (self.max_val.item() - self.min_val.item()) / self.bins
norm = 0.0
dst_bin_width = bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
if dst_bin_width == 0.0:
return 0.0
for src_bin in range(self.bins):
# distances from the beginning of first dst_bin to the beginning and
# end of src_bin
src_bin_begin = (src_bin - next_start_bin) * bin_width
src_bin_end = src_bin_begin + bin_width
# which dst_bins the beginning and end of src_bin belong to?
dst_bin_of_begin = min(
self.dst_nbins - 1, max(0.0, math.floor(src_bin_begin / dst_bin_width))
)
dst_bin_of_end = min(
self.dst_nbins - 1, max(0.0, math.floor(src_bin_end / dst_bin_width))
)
dst_bin_of_begin_center = (
dst_bin_of_begin * dst_bin_width + dst_bin_width / 2
)
density = self.histogram[src_bin] / bin_width
if dst_bin_of_begin == dst_bin_of_end:
# if src_bin is entirely within 1 dst_bin
delta_begin = src_bin_begin - dst_bin_of_begin_center
delta_end = src_bin_end - dst_bin_of_begin_center
norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
else:
delta_begin = src_bin_begin - dst_bin_of_begin_center
delta_end = dst_bin_width / 2
norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
norm = norm + (dst_bin_of_end - dst_bin_of_begin - 1) * _get_norm(
-dst_bin_width / 2, dst_bin_width / 2, density, norm_type
)
dst_bin_of_end_center = (
dst_bin_of_end * dst_bin_width + dst_bin_width / 2
)
delta_begin = -dst_bin_width / 2
delta_end = src_bin_end - dst_bin_of_end_center
norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
return norm
assert self.histogram.size()[0] == self.bins, "bins mistmatch"
bin_width = (self.max_val - self.min_val) / self.bins
# cumulative sum
total = sum(self.histogram)
cSum = torch.cumsum(self.histogram, dim=0)
stepsize = 1e-5 # granularity
alpha = 0.0 # lower bound
beta = 1.0 # upper bound
start_bin = 0
end_bin = self.bins - 1
norm_min = float("inf")
while alpha < beta:
# Find the next step
next_alpha = alpha + stepsize
next_beta = beta - stepsize
# find the left and right bins between the quantile bounds
l = start_bin
r = end_bin
while l < end_bin and cSum[l] < next_alpha * total:
l = l + 1
while r > start_bin and cSum[r] > next_beta * total:
r = r - 1
# decide the next move
next_start_bin = start_bin
next_end_bin = end_bin
if (l - start_bin) > (end_bin - r):
# move the start bin
next_start_bin = l
alpha = next_alpha
else:
# move the end bin
next_end_bin = r
beta = next_beta
if next_start_bin == start_bin and next_end_bin == end_bin:
continue
# calculate the quantization error using next_start_bin and next_end_bin
norm = _compute_quantization_error(next_start_bin, next_end_bin, "L2")
if norm > norm_min:
break
norm_min = norm
start_bin = next_start_bin
end_bin = next_end_bin
new_min = self.min_val + bin_width * start_bin
new_max = self.min_val + bin_width * (end_bin + 1)
return new_min, new_max
@torch.jit.ignore
def _adjust_min_max(self, combined_min, combined_max, upsample_rate):
# type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor, int, int]
# We ensure that:
# (combined_max - combined_min)/(downsample_rate*Nbins) = (max - min)/(upsample_rate*Nbins)
# This allows us to have a common grid of resolution s, where we can align
# the input histogram
# start_idx maps min_val to the histogram bin index.
hist_bin_width = (self.max_val - self.min_val) / (self.bins * upsample_rate)
downsample_rate = torch.ceil((combined_max - combined_min) / (self.bins * hist_bin_width)).to(torch.int).item()
e = downsample_rate * (self.bins * hist_bin_width) - (combined_max - combined_min)
combined_max = combined_max + e / 2
combined_min = combined_min - e / 2
start_idx = torch.round((self.min_val - combined_min) / hist_bin_width).to(torch.int).item()
return combined_min, combined_max, downsample_rate, start_idx
@torch.jit.ignore
def _combine_histograms(self, orig_hist, new_hist, upsample_rate, downsample_rate, start_idx, Nbins):
# type: (Tensor, Tensor, int, int, int, int) -> Tensor
# First up-sample the histogram with new data by a factor of L
# This creates an approximate probability density thats piecwise constant
upsampled_histogram = new_hist.repeat_interleave(upsample_rate)
# Now insert the upsampled histogram into the output
# histogram, which is initialized with zeros.
# The offset at which the histogram is introduced is determined
# by the start index as the output histogram can cover a wider range
histogram_with_output_range = torch.zeros((Nbins * downsample_rate))
histogram_with_output_range[start_idx:Nbins * upsample_rate + start_idx] = upsampled_histogram
# Compute integral histogram, double precision is needed to ensure
# that there are no overflows
integral_histogram = torch.cumsum(histogram_with_output_range, 0,
dtype=torch.double)[downsample_rate - 1 :: downsample_rate]
# Finally perform interpolation
shifted_integral_histogram = torch.zeros((Nbins))
shifted_integral_histogram[1:Nbins] = integral_histogram[0:-1]
interpolated_histogram = (integral_histogram - shifted_integral_histogram) / upsample_rate
orig_hist = orig_hist + interpolated_histogram.to(torch.float)
return orig_hist
def forward(self, x_orig):
# type: (Tensor) -> Tensor
x = x_orig.detach()
min_val = self.min_val
max_val = self.max_val
if min_val is None or max_val is None:
min_val = torch.min(x)
max_val = torch.max(x)
self.min_val = min_val
self.max_val = max_val
self.histogram = torch.histc(x, self.bins, min=min_val, max=max_val)
else:
new_min = torch.min(x)
new_max = torch.max(x)
combined_min = torch.min(new_min, min_val)
combined_max = torch.max(new_max, max_val)
# combine the existing histogram and new histogram into 1 histogram
# We do this by first upsampling the histogram to a dense grid
# and then downsampling the histogram efficiently
combined_min, combined_max, downsample_rate, start_idx = \
self._adjust_min_max(combined_min, combined_max, self.upsample_rate)
combined_histogram = torch.histc(x, self.bins, min=combined_min, max=combined_max)
if combined_min == min_val and combined_max == max_val:
combined_histogram += self.histogram
else:
combined_histogram = self._combine_histograms(
combined_histogram,
self.histogram,
self.upsample_rate,
downsample_rate,
start_idx,
self.bins)
self.histogram = combined_histogram
self.min_val = combined_min
self.max_val = combined_max
return x
@torch.jit.export
def calculate_qparams(self):
if self.min_val is None or self.max_val is None:
warnings.warn(
"must run observer before calling calculate_qparams.\
Returning default scale and zero point "
)
return torch.tensor([1.0]), torch.tensor([0])
assert self.bins == len(self.histogram), (
"The number of bins in histogram should be equal to the number of bins "
"supplied while making this observer"
)
new_min, new_max = self._non_linear_param_search()
return self._calculate_qparams(new_min, new_max)
def _save_to_state_dict(self, destination, prefix, keep_vars):
super(HistogramObserver, self)._save_to_state_dict(destination, prefix, keep_vars)
destination[prefix + 'min_val'] = self.min_val
destination[prefix + 'max_val'] = self.max_val
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
self.min_val = state_dict.pop(prefix + 'min_val')
self.max_val = state_dict.pop(prefix + 'max_val')
super(HistogramObserver, self)._load_from_state_dict(state_dict, prefix, local_metadata, False,
missing_keys, unexpected_keys, error_msgs)
[docs]class RecordingObserver(_ObserverBase):
r"""
The module is mainly for debug and records the tensor values during runtime.
Args:
dtype: Quantized data type
qscheme: Quantization scheme to be used
reduce_range: Reduces the range of the quantized data type by 1 bit
"""
__annotations__ = {"tensor_val": List[Optional[torch.Tensor]]}
def __init__(self, **kwargs):
super(RecordingObserver, self).__init__(**kwargs)
self.tensor_val = []
def forward(self, x):
self.tensor_val.append(x.clone())
return x
@torch.jit.export
def calculate_qparams(self):
raise Exception("calculate_qparams should not be called for RecordingObserver")
@torch.jit.export
def get_tensor_value(self):
return self.tensor_val
[docs]class NoopObserver(ObserverBase):
r"""
Observer that doesn't do anything and just passes its configuration to the
quantized module's ``.from_float()``.
Primarily used for quantization to float16 which doesn't require determining
ranges.
Args:
dtype: Quantized data type
"""
def __init__(self, dtype=torch.float16):
if dtype != torch.float16:
raise ValueError("Only float16 quantization can be used without calibration process")
super(NoopObserver, self).__init__(dtype=dtype)
def forward(self, x):
return x
def calculate_qparams(self):
raise Exception("calculate_qparams should not be called for NoopObserver")
# Restrict activations to be in the range (0,127)
default_observer = MinMaxObserver.with_args(reduce_range=True)
default_debug_observer = RecordingObserver
default_weight_observer = MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
default_histogram_observer = HistogramObserver.with_args(reduce_range=True)
default_per_channel_weight_observer = PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric)