metrics
Metrics widely used for different datasets in HAT.
| Member | Summary |
|---|---|
| acc.Accuracy | Computes the classification accuracy score. |
| acc.AccuracySeg | Computes segmentation accuracy. |
| acc.TopKAccuracy | Computes top-k prediction accuracy. |
| acc.AccuracyAttrMultiLabel | Computes the multi-label classification accuracy score. |
| argoverse2_metric.Brier | |
| argoverse2_metric.minADE | |
| argoverse2_metric.minAHE | |
| argoverse2_metric.minFDE | |
| argoverse2_metric.minFHE | |
| argoverse2_metric.MR | |
| argoverse2_metric.HitRate | |
| argoverse_metric.ArgoverseMetric | Evaluation for Argoverse motion forecasting. |
| coco_detection.COCODetectionMetric | Evaluation using the COCO protocol. |
| kitti2d_detection.Kitti2DMetric | KITTI 2D detection metric. |
| kitti3d_detection.Kitti3DMetricDet | |
| loss_show.LossShow | Shows the loss. |
| mean_iou.MeanIOU | Evaluates segmentation results. |
| metric_keypoints.MeanKeypointDist | Calculates the mean distance between keypoints. |
| metric_keypoints.PCKMetric | Computes the PCK (Proportion of Correct Keypoints) metric. |
| metric_lane_detection.CulaneF1Score | Metric for the lane detection task, used for the CULane dataset. |
| metric_optical_flow.EndPointError | Metric for the optical flow task: endpoint error (EPE). |
| mot_metrics.MotMetric | Evaluation for multi-object tracking (MOT). |
| nuscenes_map_metric.NuscenesMapMetric | Evaluation for nuScenes map prediction. |
| nuscenes_metric.NuscenesMetric | Evaluation for nuScenes detection. |
| nuscenes_metric.NuscenesMonoMetric | Evaluation for nuScenes monocular 3D detection. |
| voc_detection.VOCMApMetric | Calculates mean AP for the object detection task. |
| voc_detection.VOC07MApMetric | Mean average precision metric for the PASCAL VOC 2007 dataset. |
API Reference
class hat.metrics.acc.Accuracy(axis=1, name='accuracy')
Computes the classification accuracy score.
- Parameters:
- axis (int) – The axis that represents classes.
- name (str) – Name of this metric instance for display.
update(labels, preds)
Update the state variables with a batch of labels and predictions.
class hat.metrics.acc.AccuracyAttrMultiLabel(name: str = 'accuracy', attr_type_name: str = '', attr_type_list: List | None = None, attr_type_numcls: List | None = None, ignore_idx: List | None = None)
Computes the multi-label classification accuracy score.
- Parameters:
- name (str) – Name of this metric instance for display.
- attr_type_name (str) – Name of the specific type for display.
- attr_type_list (List) – List of all types.
- attr_type_numcls (List) – Number of categories for each type.
- ignore_idx (List) – Indices of the categories to be ignored.
update(labels, preds)
Update the state variables with a batch of labels and predictions.
class hat.metrics.acc.AccuracySeg(name='accuracy', axis=1)
Computes segmentation accuracy.
update(output)
Update the state variables with a new model output.
class hat.metrics.acc.TopKAccuracy(top_k, name='top_k_accuracy')
Computes top k predictions accuracy.
TopKAccuracy differs from Accuracy in that it considers a prediction
correct as long as the ground-truth label is among the top K
predicted labels.
If top_k = 1, then TopKAccuracy is identical to Accuracy.
- Parameters:
- top_k (int) – Number of highest-scoring predictions in which the target is searched.
- name (str) – Name of this metric instance for display.
update(labels, preds)
Update the state variables with a batch of labels and predictions.
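The top-k rule above can be illustrated in plain Python. This is a sketch, not the HAT implementation: the `top_k_accuracy` helper is hypothetical, and predictions are assumed to be one list of class scores per sample (classes along the last axis, as with the default `axis=1` for 2-D predictions).

```python
from typing import Sequence


def top_k_accuracy(labels: Sequence[int], preds: Sequence[Sequence[float]], top_k: int = 1) -> float:
    """Fraction of samples whose label is among the top_k highest scores.

    Hypothetical helper: preds[i][c] is the score of class c for sample i.
    """
    if len(labels) != len(preds) or not labels:
        raise ValueError("labels and preds must be non-empty and of equal length")
    correct = 0
    for label, scores in zip(labels, preds):
        # Class indices ordered from highest to lowest score.
        ranked = sorted(range(len(scores)), key=lambda c: scores[c], reverse=True)
        if label in ranked[:top_k]:
            correct += 1
    return correct / len(labels)


labels = [0, 2, 1]
preds = [
    [0.7, 0.2, 0.1],  # argmax is 0: correct even for top_k=1
    [0.5, 0.1, 0.4],  # label 2 is only ranked second
    [0.3, 0.6, 0.1],  # argmax is 1: correct even for top_k=1
]
print(top_k_accuracy(labels, preds, top_k=1))  # 0.6666666666666666
print(top_k_accuracy(labels, preds, top_k=2))  # 1.0
```

With `top_k=1` the check reduces to comparing each label with the argmax, which is why TopKAccuracy then matches Accuracy.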
class hat.metrics.argoverse2_metric.Brier(max_guesses: int = 6, name='Brier', **kwargs)
compute()
Compute the final result from the metric states.
All state variables registered with self.add_state are synchronized
across devices before this method is executed.
update(pred: Tensor, target: Tensor, prob: Tensor | None = None, valid_mask: Tensor | None = None, keep_invalid_final_step: bool = True, min_criterion: str = 'FDE', **kwargs)
Update the state variables with a new batch of predictions and targets.
class hat.metrics.argoverse2_metric.HitRate(max_guesses: int = 6, name='HitRate', **kwargs)
compute()
Compute the final result from the metric states.
All state variables registered with self.add_state are synchronized
across devices before this method is executed.
update(pred: Tensor, target: Tensor, prob: Tensor | None = None, valid_mask: Tensor | None = None, keep_invalid_final_step: bool = True, miss_criterion: str = 'FDE', miss_threshold: float = 2.0, **kwargs)
Update the state variables with a new batch of predictions and targets.
class hat.metrics.argoverse2_metric.MR(max_guesses: int = 6, name='MR', **kwargs)
compute()
Compute the final result from the metric states.
All state variables registered with self.add_state are synchronized
across devices before this method is executed.
update(pred: Tensor, target: Tensor, prob: Tensor | None = None, valid_mask: Tensor | None = None, keep_invalid_final_step: bool = True, miss_criterion: str = 'FDE', miss_threshold: float = 2.0, **kwargs)
Update the state variables with a new batch of predictions and targets.
class hat.metrics.argoverse2_metric.minADE(max_guesses: int = 6, name='minADE', **kwargs)
compute()
Compute the final result from the metric states.
All state variables registered with self.add_state are synchronized
across devices before this method is executed.
update(pred: Tensor, target: Tensor, prob: Tensor | None = None, valid_mask: Tensor | None = None, keep_invalid_final_step: bool = True, min_criterion: str = 'FDE', **kwargs)
Update the state variables with a new batch of predictions and targets.
class hat.metrics.argoverse2_metric.minAHE(max_guesses: int = 6, name='minAHE', **kwargs)
compute()
Compute the final result from the metric states.
All state variables registered with self.add_state are synchronized
across devices before this method is executed.
update(pred: Tensor, target: Tensor, prob: Tensor | None = None, valid_mask: Tensor | None = None, keep_invalid_final_step: bool = True, min_criterion: str = 'FDE', **kwargs)
Update the state variables with a new batch of predictions and targets.
class hat.metrics.argoverse2_metric.minFDE(max_guesses: int = 6, name='minFDE', **kwargs)
compute()
Compute the final result from the metric states.
All state variables registered with self.add_state are synchronized
across devices before this method is executed.
update(pred: Tensor, target: Tensor, prob: Tensor | None = None, valid_mask: Tensor | None = None, keep_invalid_final_step: bool = True, **kwargs)
Update the state variables with a new batch of predictions and targets.
class hat.metrics.argoverse2_metric.minFHE(max_guesses: int = 6, name='minFHE', **kwargs)
compute()
Compute the final result from the metric states.
All state variables registered with self.add_state are synchronized
across devices before this method is executed.
update(pred: Tensor, target: Tensor, prob: Tensor | None = None, valid_mask: Tensor | None = None, keep_invalid_final_step: bool = True, **kwargs)
Update the state variables with a new batch of predictions and targets.
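The displacement metrics in this group follow the usual multi-modal forecasting pattern: among up to max_guesses predicted trajectories, one best mode is selected and its errors are reported. The sketch below is an illustration under stated assumptions, not the HAT code: it assumes the best mode is the one with the lowest final displacement error (as the min_criterion='FDE' default suggests) and that a miss is an FDE above miss_threshold. It ignores prob, valid_mask, and the heading-based minAHE/minFHE; `best_mode_errors` is a hypothetical helper.

```python
import math
from typing import Sequence, Tuple

Point = Tuple[float, float]


def _dist(a: Point, b: Point) -> float:
    return math.hypot(a[0] - b[0], a[1] - b[1])


def best_mode_errors(
    modes: Sequence[Sequence[Point]],
    target: Sequence[Point],
    max_guesses: int = 6,
    miss_threshold: float = 2.0,
) -> Tuple[float, float, bool]:
    """Return (minADE, minFDE, is_miss) for one agent.

    modes: K predicted trajectories, each with the same length as target.
    Only the first max_guesses modes are considered.
    """
    candidates = list(modes)[:max_guesses]
    if not candidates:
        raise ValueError("at least one predicted mode is required")
    # Assumed criterion: pick the mode whose final point is closest to the target's.
    fdes = [_dist(m[-1], target[-1]) for m in candidates]
    best = min(range(len(candidates)), key=lambda k: fdes[k])
    ade = sum(_dist(p, t) for p, t in zip(candidates[best], target)) / len(target)
    return ade, fdes[best], fdes[best] > miss_threshold


target = [(1.0, 0.0), (2.0, 0.0), (3.0, 0.0)]
modes = [
    [(1.0, 1.0), (2.0, 1.0), (3.0, 1.0)],  # offset by 1 m at every step
    [(1.0, 0.0), (2.0, 0.0), (6.0, 0.0)],  # exact except for the last point
]
print(best_mode_errors(modes, target))  # (1.0, 1.0, False)
```

Averaging the per-agent values over a dataset gives minADE and minFDE, and the fraction of misses gives MR.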
class hat.metrics.argoverse_metric.ArgoverseMetric(name: str = 'ArgoverseMetric', max_guesses: int = 6, horizon: int = 30, miss_threshold: float = 2.0)
Evaluation for Argoverse motion forecasting.
- Parameters:
- name – Name of this metric instance for display.
- max_guesses – Number of guesses (predicted trajectories) allowed.
- horizon – Prediction horizon, in time steps.
- miss_threshold – Distance threshold for the last predicted
coordinate; a prediction whose final displacement exceeds it counts as a miss.
compute()
Compute the final result from the metric states.
All state variables registered with self.add_state are synchronized
across devices before this method is executed.
get()
Get the current evaluation result.
To skip synchronization among devices, override this method
and calculate the results without calling self.compute().
- Returns:
- names – Names of the metrics.
- values – Values of the evaluations.
get_ade(forecasted_trajectory: ndarray, gt_trajectory: ndarray)
Compute the Average Displacement Error (ADE).
- Parameters:
- forecasted_trajectory – Predicted trajectory with shape
(pred_len x 2).
- gt_trajectory – Ground truth trajectory with shape
(pred_len x 2).
- Returns:
- ade – Average Displacement Error.
get_fde(forecasted_trajectory: ndarray, gt_trajectory: ndarray)
Compute the Final Displacement Error (FDE).
- Parameters:
- forecasted_trajectory – Predicted trajectory with shape
(pred_len x 2).
- gt_trajectory – Ground truth trajectory with shape
(pred_len x 2).
- Returns:
- fde – Final Displacement Error.
reset()
Reset the metric state variables to their default values.
If a subclass has state variables that are not registered
with self.add_state and need to be reset regularly,
extend this method in the subclass.
update(meta, preds)
Update the state variables with new meta information and predictions.
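get_ade and get_fde correspond to the standard definitions: ADE is the mean Euclidean distance between corresponding predicted and ground-truth points over all pred_len steps, and FDE is the distance at the last step. A minimal pure-Python sketch, assuming trajectories are given as lists of (x, y) pairs instead of ndarrays; the `ade` and `fde` functions are hypothetical stand-ins for the methods.

```python
import math
from typing import Sequence, Tuple

Trajectory = Sequence[Tuple[float, float]]


def ade(forecasted: Trajectory, gt: Trajectory) -> float:
    """Average Displacement Error over all pred_len steps."""
    if len(forecasted) != len(gt) or not gt:
        raise ValueError("trajectories must be non-empty and of equal length")
    total = sum(math.hypot(px - gx, py - gy) for (px, py), (gx, gy) in zip(forecasted, gt))
    return total / len(gt)


def fde(forecasted: Trajectory, gt: Trajectory) -> float:
    """Final Displacement Error: distance between the last points."""
    (px, py), (gx, gy) = forecasted[-1], gt[-1]
    return math.hypot(px - gx, py - gy)


pred = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]
gt = [(0.0, 0.0), (1.0, 0.0), (5.0, 4.0)]
print(fde(pred, gt))  # 5.0
print(ade(pred, gt))  # (0 + 0 + 5) / 3
```

Under the default miss_threshold of 2.0, this prediction would count as a miss, since its FDE exceeds the threshold.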
class hat.metrics.coco_detection.COCODetectionMetric(ann_file: str, val_interval: int = 1, name: str = 'COCOMeanAP', save_prefix: str = './WORKSPACE/results', adas_eval_task: str | None = None, use_time: bool = True, cleanup: bool = False, warn_without_compute: bool = False)
Evaluation using the COCO protocol.
- Parameters:
- ann_file – path to the validation annotation JSON file.
- val_interval – evaluation interval.
- name – name of this metric instance for display.
- save_prefix – path to save results.
- adas_eval_task – task name for adas-eval, such as ‘vehicle’, ‘person’
and so on.
- use_time – whether to include the time in the result name.
- cleanup – whether to clean up the saved results when the process ends.
- Raises:
RuntimeError – fail to write json to disk.
get()
Get evaluation metrics.
reset()
Reset the metric state variables to their default values.
If a subclass has state variables that are not registered
with self.add_state and need to be reset regularly,
extend this method in the subclass.
update(output: Dict)
Update the internal buffer with the latest predictions.
Note that the statistics are not available until
self.get() is called to return the metrics.
- Parameters:
output – A dict of model output which includes det results and
image infos.
class hat.metrics.kitti2d_detection.Kitti2DMetric(anno_file: str, name: str = 'kittiAP', is_plot: bool = True)
Kitti2D detection metric.
For details, you can refer to
http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=2d.
- Parameters:
- anno_file (str) – validation data annotation json file path.
- name – name of this metric instance for display.
- is_plot – whether to plot the PR curve.
get()
Get the current evaluation result.
To skip synchronization among devices, override this method
and calculate the results without calling self.compute().
- Returns:
- names – Names of the metrics.
- values – Values of the evaluations.
reset()
Reset the metric state variables to their default values.
If a subclass has state variables that are not registered
with self.add_state and need to be reset regularly,
extend this method in the subclass.
update(output: Dict)
- Parameters:
- output – A dict of model output which includes detection results and
image infos. Supports batch_size >= 1.
- output['pred_bboxes'] (List[torch.Tensor]) – Network output
for each input.
- output['img_name'] (List[str]) – Image name for each input.
class hat.metrics.kitti3d_detection.Kitti3DMetricDet(current_classes: List[str], compute_aos: bool = False, name: str = 'kitti3dAPDet', difficultys: List | None = None)
get()
Get the current evaluation result.
To skip synchronization among devices, override this method
and calculate the results without calling self.compute().
- Returns:
- names – Names of the metrics.
- values – Values of the evaluations.
reset()
Reset the metric state variables to their default values.
If a subclass has state variables that are not registered
with self.add_state and need to be reset regularly,
extend this method in the subclass.
update(preds, labels)
Update the state variables with a new batch of predictions and labels.
class hat.metrics.loss_show.LossShow(name: str = 'loss', norm: bool = True)
Show loss.
- Parameters:
- name – Name of this metric instance for display.
- norm – Whether to normalize the loss when it has more than one element.
If True, the mean loss is calculated; otherwise, the loss sum.
Default: True.
get()
Get the current evaluation result.
To skip synchronization among devices, override this method
and calculate the results without calling self.compute().
- Returns:
- names – Names of the metrics.
- values – Values of the evaluations.
reset()
Reset the metric state variables to their default values.
If a subclass has state variables that are not registered
with self.add_state and need to be reset regularly,
extend this method in the subclass.
update(loss: Tensor | Dict[str, Tensor])
Update the state variables with a new loss tensor or dict of loss tensors.
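One reading of the norm flag, sketched with plain Python lists standing in for loss tensors: a multi-element loss is reduced to its mean when norm=True and to its sum otherwise, and a dict of losses is reduced entry by entry. The `reduce_loss` and `reduce_losses` helpers, and the "loss" key used for a bare loss, are hypothetical and only illustrate the documented behavior.

```python
from typing import Dict, List, Union

LossValue = Union[float, List[float]]


def reduce_loss(loss: LossValue, norm: bool = True) -> float:
    """Reduce a (possibly multi-element) loss to a scalar."""
    if isinstance(loss, (int, float)):
        return float(loss)
    if not loss:
        raise ValueError("empty loss")
    if len(loss) == 1:
        return float(loss[0])
    # More than one element: mean if norm, else sum.
    return sum(loss) / len(loss) if norm else float(sum(loss))


def reduce_losses(losses: Union[LossValue, Dict[str, LossValue]], norm: bool = True) -> Dict[str, float]:
    if isinstance(losses, dict):
        return {name: reduce_loss(value, norm) for name, value in losses.items()}
    # Assumed key for a bare loss, mirroring the default metric name.
    return {"loss": reduce_loss(losses, norm)}


losses = {"cls_loss": [0.5, 1.5], "reg_loss": 0.25}
print(reduce_losses(losses, norm=True))   # {'cls_loss': 1.0, 'reg_loss': 0.25}
print(reduce_losses(losses, norm=False))  # {'cls_loss': 2.0, 'reg_loss': 0.25}
```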
class hat.metrics.mean_iou.MeanIOU(seg_class: List[str], name: str = 'MeanIOU', ignore_index: int = 255, global_ignore_index: Sequence | int = 255, verbose: bool = False)
Evaluates segmentation results.
- Parameters:
- seg_class (list(str)) – A list of the classes in the segmentation dataset,
in the same order as the label indices.
- name (str) – Name of this metric instance for display, also used as
the monitored parameter for Checkpoint.
- ignore_index (int) – The label index that will be ignored in evaluation.
- global_ignore_index (list, int) – The label index, or list of label indices,
ignored in global evaluation such as mIoU, mAcc, and aAcc.
- verbose (bool) – Whether to return verbose values for aidi eval.
Default: False.
compute()
Get evaluation metrics.
update(label: Tensor, preds: Sequence[Tensor] | Tensor)
Update the internal buffer with the latest predictions.
Note that the statistics are not available until
self.get() is called to return the metrics.
- Parameters:
- preds – Model output.
- label – Ground truth.
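mIoU is commonly computed from a per-class confusion matrix, with IoU_c = TP_c / (TP_c + FP_c + FN_c) averaged over classes. The sketch below shows that computation on flattened label and prediction lists, dropping pixels labeled ignore_index. It is a simplified illustration (no global_ignore_index handling, mAcc, or aAcc) and not the MeanIOU implementation; `mean_iou` is a hypothetical helper, and predictions are assumed to lie in range(num_classes).

```python
from typing import List, Sequence


def mean_iou(labels: Sequence[int], preds: Sequence[int], num_classes: int, ignore_index: int = 255) -> float:
    """Mean IoU over classes that appear in the labels or predictions."""
    # confusion[g][p] counts pixels with ground truth g predicted as p.
    confusion = [[0] * num_classes for _ in range(num_classes)]
    for g, p in zip(labels, preds):
        if g == ignore_index:
            continue  # ignored pixels do not contribute
        confusion[g][p] += 1
    ious: List[float] = []
    for c in range(num_classes):
        tp = confusion[c][c]
        fn = sum(confusion[c]) - tp
        fp = sum(confusion[g][c] for g in range(num_classes)) - tp
        union = tp + fp + fn
        if union > 0:  # skip classes absent from both labels and predictions
            ious.append(tp / union)
    return sum(ious) / len(ious) if ious else 0.0


labels = [0, 0, 1, 1, 255]
preds = [0, 1, 1, 1, 0]
print(mean_iou(labels, preds, num_classes=2))  # class IoUs 1/2 and 2/3, mean 7/12
```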
class hat.metrics.metric_keypoints.MeanKeypointDist(name: str = 'mean_dist', feat_stride: int = 4, decode_mode: str = 'averaged')
This metric calculates the mean distance between keypoints.
- Parameters:
- name – Name of the metric.
- feat_stride – Stride of the feature map with respect to the input image.
- decode_mode – Mode for decoding the predicted keypoints, either
“averaged” or “diff_sign”.
update(data)
Update the state variables with a new batch of data.
class hat.metrics.metric_keypoints.PCKMetric(alpha: float, feat_stride: int, img_shape: Tuple[int], decode_mode: str = 'diff_sign')
Compute PCK (Proportion of Correct Keypoints) metric.
- Parameters:
- alpha – Parameter alpha for defining the PCK threshold as a
percentage of the object’s size.
- feat_stride – Stride of the feature map with respect to the input image.
- img_shape – Shape of the input image in (height, width) format.
- decode_mode – Mode for decoding the predicted keypoints, either
“averaged” or “diff_sign”.
update(data)
Update the state variables with a new batch of data.
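PCK counts a predicted keypoint as correct when its distance to the ground truth is within alpha times a reference size. The sketch assumes that reference size is passed in directly (for example, the larger side of the object box) and that the comparison is inclusive; how PCKMetric derives the size from img_shape and feat_stride is not shown. The `pck` helper is hypothetical.

```python
import math
from typing import Sequence, Tuple

Point = Tuple[float, float]


def pck(pred: Sequence[Point], gt: Sequence[Point], alpha: float, ref_size: float) -> float:
    """Proportion of keypoints within alpha * ref_size of the ground truth."""
    if len(pred) != len(gt) or not gt:
        raise ValueError("pred and gt must be non-empty and of equal length")
    threshold = alpha * ref_size
    correct = sum(
        1 for (px, py), (gx, gy) in zip(pred, gt) if math.hypot(px - gx, py - gy) <= threshold
    )
    return correct / len(gt)


gt = [(10.0, 10.0), (20.0, 20.0), (30.0, 30.0), (40.0, 40.0)]
pred = [(11.0, 10.0), (20.0, 26.0), (30.0, 30.0), (40.0, 43.0)]
# Distances are 1, 6, 0 and 3; the threshold is 0.05 * 100 = 5.
print(pck(pred, gt, alpha=0.05, ref_size=100.0))  # 0.75
```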
class hat.metrics.metric_lane_detection.CulaneF1Score(name: str = 'CulaneF1Score', iou_thresh: float = 0.5, img_shape: Tuple[int, int, int] = (590, 1640, 1), width: int = 30, samples: int = 50)
Metric for the lane detection task, used for the CULane dataset.
This metric aligns with the official C++ implementation.
- Parameters:
- name – Metric name.
- iou_thresh – IoU overlap threshold for a true positive (TP).
- img_shape – Image shape used when calculating IoU.
- width – The width of the lane line.
- samples – Number of samples between two points.
compute()
Compute the final result from the metric states.
All state variables registered with self.add_state are synchronized
across devices before this method is executed.
update(annos: List[List[ndarray]], preds: List[List[ndarray]])
Update the state variables with a new batch of annotations and predictions.
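The final CULane score is an F1 over lane matches, with TP, FP, and FN counts accumulated over the whole dataset. The official tool matches predicted and ground-truth lanes with the Hungarian algorithm; the sketch below substitutes a simpler greedy matching on a precomputed IoU matrix and omits the lane IoU computation itself (img_shape, width, samples), so it is an approximation for illustration only. The strict "above iou_thresh" comparison and the `match_lanes`/`f1_score` helpers are assumptions.

```python
from typing import Sequence, Tuple


def match_lanes(iou: Sequence[Sequence[float]], num_gt: int, iou_thresh: float = 0.5) -> Tuple[int, int, int]:
    """Greedily match lanes on an IoU matrix iou[pred][gt]; return (tp, fp, fn)."""
    num_pred = len(iou)
    # Candidate pairs sorted by IoU, best first (greedy stand-in for Hungarian matching).
    pairs = sorted(((iou[p][g], p, g) for p in range(num_pred) for g in range(num_gt)), reverse=True)
    matched_pred, matched_gt = set(), set()
    for value, p, g in pairs:
        if value <= iou_thresh:
            break  # all remaining pairs are at or below the threshold
        if p not in matched_pred and g not in matched_gt:
            matched_pred.add(p)
            matched_gt.add(g)
    tp = len(matched_pred)
    return tp, num_pred - tp, num_gt - tp


def f1_score(tp: int, fp: int, fn: int) -> float:
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0


image_1 = match_lanes([[0.8, 0.1], [0.2, 0.3]], num_gt=2)  # (1, 1, 1)
image_2 = match_lanes([[0.6]], num_gt=1)  # (1, 0, 0)
# Counts are summed over images before computing F1.
tp, fp, fn = (sum(c) for c in zip(image_1, image_2))
print(f1_score(tp, fp, fn))  # precision = recall = 2/3, so F1 is about 0.667
```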
class hat.metrics.metric_optical_flow.EndPointError(name='EPE', use_mask=False)
Metric for the optical flow task: endpoint error (EPE).
The endpoint error measures the distance between the
endpoints of two optical flow vectors (u0, v0) and (u1, v1)
and is defined as sqrt((u0 - u1) ** 2 + (v0 - v1) ** 2).
- Parameters:
name – Metric name.
- Refs:
https://github.com/philferriere/tfoptflow/blob/master/tfoptflow/model_pwcnet.py
update(labels, preds, masks=None)
Update the state variables with new labels, predictions, and optional masks.
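The EPE definition above, averaged over pixels, is sketched below with flow fields as flat lists of (u, v) pairs. Treating masks as a per-pixel validity flag that restricts the average (as use_mask and the masks argument of update suggest) is an assumption about this class, and `end_point_error` is a hypothetical helper.

```python
import math
from typing import Optional, Sequence, Tuple

Flow = Sequence[Tuple[float, float]]


def end_point_error(labels: Flow, preds: Flow, masks: Optional[Sequence[bool]] = None) -> float:
    """Mean of sqrt((u0 - u1) ** 2 + (v0 - v1) ** 2) over valid pixels."""
    if masks is None:
        masks = [True] * len(labels)
    errors = [
        math.hypot(u0 - u1, v0 - v1)
        for (u0, v0), (u1, v1), valid in zip(labels, preds, masks)
        if valid  # assumed: masked-out pixels are excluded from the mean
    ]
    return sum(errors) / len(errors) if errors else 0.0


gt = [(0.0, 0.0), (1.0, 1.0), (2.0, 0.0)]
pred = [(3.0, 4.0), (1.0, 1.0), (2.0, 2.0)]
print(end_point_error(gt, pred))                       # (5 + 0 + 2) / 3
print(end_point_error(gt, pred, [False, True, True]))  # (0 + 2) / 2 = 1.0
```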
class hat.metrics.mot_metrics.MotMetric(gt_dir: str, name: str = 'MOTA', save_prefix: str = './WORKSPACE/motresults', cleanup: bool = False)
Evaluation for multi-object tracking (MOT).
- Parameters:
- gt_dir – directory of the validation ground truth.
- name – name of this metric instance for display.
- save_prefix – path to save results.
- cleanup – whether to clean up the saved results when the process ends.
get()
Get evaluation metrics.
reset()
Reset the metric state variables to their default values.
If a subclass has state variables that are not registered
with self.add_state and need to be reset regularly,
extend this method in the subclass.
update(outputs: Dict)
Update the internal buffer with the latest predictions.
Note that the statistics are not available until
self.get() is called to return the metrics.
- Parameters:
outputs – A dict of model output which includes detection results and
image infos.
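The default metric name, MOTA, refers to multi-object tracking accuracy, conventionally defined as MOTA = 1 - (FN + FP + IDSW) / GT, with counts summed over all frames. A minimal sketch of that formula only; obtaining the per-frame counts (matching tracks to ground truth) is left out, the per-frame tuples are made-up illustration data, and `mota` is a hypothetical helper.

```python
from typing import Iterable, Tuple


def mota(frames: Iterable[Tuple[int, int, int, int]]) -> float:
    """frames yields (num_gt, false_negatives, false_positives, id_switches) per frame."""
    total_gt = fn = fp = idsw = 0
    for num_gt, f_n, f_p, switches in frames:
        total_gt += num_gt
        fn += f_n
        fp += f_p
        idsw += switches
    if total_gt == 0:
        raise ValueError("MOTA is undefined without ground-truth objects")
    return 1.0 - (fn + fp + idsw) / total_gt


per_frame = [(5, 1, 0, 0), (5, 0, 1, 1), (5, 0, 0, 0)]
print(mota(per_frame))  # 1 - 3 / 15, about 0.8
```

Note that MOTA can become negative when errors outnumber ground-truth objects.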
class hat.metrics.nuscenes_map_metric.NuscenesMapMetric(name: str = 'NuscenesMapMetric', eval_use_same_gt_sample_num_flag=False, save_prefix: str = './WORKSPACE/results', fixed_ptsnum_per_line: int = -1, pc_range: Sequence[float] = None, classes: Sequence[str] = None, metric: str = 'chamfer', map_ann_file: str = None)
Evaluation for nuScenes map prediction.
- Parameters:
- name – Name of this metric instance for display.
- eval_use_same_gt_sample_num_flag – Whether to use the same GT sample
number for evaluation.
- save_prefix – Path to save result.
- fixed_ptsnum_per_line – Number of fixed points per line.
- pc_range – Range of point cloud.
- classes – Classes for evaluation.
- metric – Metric used for evaluation.
- map_ann_file – Map annotation file.
compute()
Compute the final result from the metric states.
All state variables registered with self.add_state are synchronized
across devices before this method is executed.
get()
Get the current evaluation result.
To skip synchronization among devices, override this method
and calculate the results without calling self.compute().
- Returns:
- names – Names of the metrics.
- values – Values of the evaluations.
reset()
Reset the metric state variables to their default values.
If a subclass has state variables that are not registered
with self.add_state and need to be reset regularly,
extend this method in the subclass.
update(batch_data, pred_results)
Update the state variables with a new batch of data and prediction results.
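With the default metric='chamfer', vectorized map elements are compared by a Chamfer distance between their sampled points. The sketch below shows one common symmetric form, the average of the two directed mean nearest-neighbor distances. The exact variant, the point sampling (fixed_ptsnum_per_line), and the AP thresholds used by NuscenesMapMetric are not shown and may differ; `chamfer_distance` is a hypothetical helper.

```python
import math
from typing import Sequence, Tuple

Point = Tuple[float, float]


def chamfer_distance(a: Sequence[Point], b: Sequence[Point]) -> float:
    """Symmetric Chamfer distance: mean of the two directed mean NN distances."""
    if not a or not b:
        raise ValueError("point sets must be non-empty")

    def directed(src: Sequence[Point], dst: Sequence[Point]) -> float:
        # Mean distance from each point in src to its nearest point in dst.
        return sum(min(math.dist(p, q) for q in dst) for p in src) / len(src)

    return 0.5 * (directed(a, b) + directed(b, a))


pred_line = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]
gt_line = [(0.0, 1.0), (1.0, 1.0), (2.0, 1.0)]
print(chamfer_distance(pred_line, gt_line))  # 1.0
```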
class hat.metrics.nuscenes_metric.NuscenesMetric(name: str = 'NuscenesMetric', data_root: str = '', version: str = 'v1.0-mini', save_prefix: str = './WORKSPACE/results', verbose: bool = True, eval_version: str = 'detection_cvpr_2019', use_lidar: bool = False, classes: Sequence[str] = None, use_ddp: bool = True, trans_lidar_dim: bool = False, trans_lidar_rot: bool = True, meta_key='meta', lidar_key='lidar2ego')
Evaluation for nuScenes detection.
- Parameters:
- name – Name of this metric instance for display.
- data_root – Data path of nuScenes data.
- version – Version of nuScenes data.
Chosen from [‘v1.0-mini’, ‘v1.0-trainval’].
- save_prefix – Path to save results.
- verbose – Whether to output verbose logs.
- eval_version – Evaluation version.
- use_lidar – Whether to use lidar bboxes.
- classes – List of class names.
- use_ddp – Whether to use DDP to evaluate the metric.
compute()
Compute the final result from the metric states.
All state variables registered with self.add_state are synchronized
across devices before this method is executed.
get()
Get the current evaluation result.
To skip synchronization among devices, override this method
and calculate the results without calling self.compute().
- Returns:
- names – Names of the metrics.
- values – Values of the evaluations.
reset()
Reset the metric state variables to their default values.
If a subclass has state variables that are not registered
with self.add_state and need to be reset regularly,
extend this method in the subclass.
update(meta, pred_bboxes)
Update the state variables with new meta information and predicted boxes.
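With the 'detection_cvpr_2019' eval version, the nuScenes benchmark summarizes results with the nuScenes detection score (NDS), which combines mAP with five true-positive error metrics (mATE, mASE, mAOE, mAVE, mAAE): NDS = (5 * mAP + sum(1 - min(1, mTP))) / 10. The sketch reproduces only that final formula from already-computed values, assuming this metric reports the standard nuScenes NDS; the input numbers are illustrative and `nuscenes_detection_score` is a hypothetical helper.

```python
from typing import Mapping


def nuscenes_detection_score(m_ap: float, tp_errors: Mapping[str, float]) -> float:
    """NDS from mAP and the five mean TP errors."""
    if len(tp_errors) != 5:
        raise ValueError("expected the five TP error metrics")
    # Each error is clipped at 1 so that a very poor TP metric contributes 0.
    tp_scores = sum(1.0 - min(1.0, err) for err in tp_errors.values())
    return (5.0 * m_ap + tp_scores) / 10.0


errors = {"mATE": 0.5, "mASE": 0.25, "mAOE": 0.5, "mAVE": 1.5, "mAAE": 0.25}
print(nuscenes_detection_score(0.4, errors))  # (2.0 + 2.5) / 10, about 0.45
```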
class hat.metrics.nuscenes_metric.NuscenesMonoMetric(nms_threshold: float = 0.05, use_cpu: bool = False, **kwargs)
Evaluation for nuScenes monocular 3D detection.
- Parameters:
nms_threshold – NMS threshold for detections within the same sample.
get()
Get the current evaluation result.
To skip synchronization among devices, override this method
and calculate the results without calling self.compute().
- Returns:
- names – Names of the metrics.
- values – Values of the evaluations.
get_attr_name(attr_idx, label_name)
Get attribute from predicted index.
This is a workaround to predict attribute when the predicted velocity
is not reliable. We map the predicted attribute index to the one
in the attribute set. If it is consistent with the category, we will
keep it. Otherwise, we will use the default attribute.
- Parameters:
- attr_idx (int) – Attribute index.
- label_name (str) – Predicted category name.
- Returns:
Predicted attribute name.
- Return type:
str
update(metas, pred_bboxes)
Update the state variables with new metas and predicted boxes.
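The fallback logic described for get_attr_name can be sketched as follows. The attribute and category names are nuScenes ones, but the index order, the category-to-attribute groups, and the default attribute per category below are illustrative assumptions and may differ from what NuscenesMonoMetric uses; `get_attr_name_sketch` is hypothetical.

```python
from typing import Dict, List

# Assumed index -> attribute mapping (illustrative order).
ATTRIBUTES: List[str] = [
    "cycle.with_rider",
    "cycle.without_rider",
    "pedestrian.moving",
    "pedestrian.standing",
    "pedestrian.sitting_lying_down",
    "vehicle.moving",
    "vehicle.parked",
    "vehicle.stopped",
]

# Assumed default attribute per category ("" means no attribute).
DEFAULT_ATTRIBUTE: Dict[str, str] = {
    "car": "vehicle.parked",
    "pedestrian": "pedestrian.moving",
    "bicycle": "cycle.without_rider",
    "barrier": "",
}

# Attribute prefix that is consistent with each category (assumed subset).
CATEGORY_PREFIX: Dict[str, str] = {
    "car": "vehicle.",
    "pedestrian": "pedestrian.",
    "bicycle": "cycle.",
}


def get_attr_name_sketch(attr_idx: int, label_name: str) -> str:
    attr = ATTRIBUTES[attr_idx] if 0 <= attr_idx < len(ATTRIBUTES) else ""
    prefix = CATEGORY_PREFIX.get(label_name)
    # Keep the predicted attribute only if it is consistent with the category.
    if prefix is not None and attr.startswith(prefix):
        return attr
    return DEFAULT_ATTRIBUTE.get(label_name, "")


print(get_attr_name_sketch(5, "car"))      # vehicle.moving (consistent, kept)
print(get_attr_name_sketch(2, "car"))      # vehicle.parked (falls back to the default)
print(get_attr_name_sketch(0, "barrier"))  # empty string (no attribute for barriers)
```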
class hat.metrics.voc_detection.VOC07MApMetric(num_classes: int, iou_thresh: float = 0.5, class_names: List[str] | None = None)
Mean average precision metric for the PASCAL VOC 2007 dataset.
- Parameters:
- num_classes – Number of classes.
- iou_thresh – IoU overlap threshold for a true positive (TP).
- class_names – If provided, AP is printed for each class.
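PASCAL VOC 2007 conventionally uses 11-point interpolated AP: at each recall level 0, 0.1, ..., 1.0, take the maximum precision achieved at any recall at or above that level, then average the 11 values. A sketch of that step, assuming the precision/recall curve has already been computed; it is not necessarily how VOC07MApMetric is written, and `voc07_ap` is a hypothetical helper.

```python
from typing import Sequence


def voc07_ap(recall: Sequence[float], precision: Sequence[float]) -> float:
    """11-point interpolated average precision."""
    ap = 0.0
    for i in range(11):
        t = i / 10  # recall levels 0.0, 0.1, ..., 1.0
        candidates = [p for r, p in zip(recall, precision) if r >= t]
        ap += max(candidates) if candidates else 0.0
    return ap / 11


# Precision 1.0 up to recall 0.5, then 0.5 at recall 1.0:
# six levels (0.0 to 0.5) score 1.0 and five levels (0.6 to 1.0) score 0.5.
print(voc07_ap([0.5, 1.0], [1.0, 0.5]))  # 8.5 / 11
```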
class hat.metrics.voc_detection.VOCMApMetric(num_classes: int, iou_thresh: float | List = 0.5, class_names: List[str] | None = None, ignore_ioa_thresh: float = 0.2, score_threshs: List[float] | None = None, max_iou_thresh: float | None = None, iou_thresh_interval: float = 0.05, cls_idx_mapping: bool = False)
Calculate mean AP for object detection task.
- Parameters:
- num_classes – Number of classes.
- iou_thresh – IoU overlap threshold for a true positive (TP).
- class_names – If provided, AP is printed for each class.
- ignore_ioa_thresh – The intersection-over-area (IoA) threshold for ignored GTs.
- score_threshs – If provided, recall/precision is printed at
each score threshold.
- max_iou_thresh – If provided, AP is calculated at each IoU
threshold from iou_thresh to max_iou_thresh with a step of
iou_thresh_interval, and the results are averaged. Must be larger than iou_thresh.
- iou_thresh_interval – The step used to generate the list of IoU thresholds.
Default: 0.05. Make sure that
max_iou_thresh - iou_thresh is divisible by this value.
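The threshold list implied by iou_thresh, max_iou_thresh, and iou_thresh_interval can be generated as below. The rounding guards against floating-point drift, and the divisibility check mirrors the requirement stated above. The `iou_thresholds` helper is a hypothetical illustration, not part of HAT.

```python
from typing import List


def iou_thresholds(iou_thresh: float, max_iou_thresh: float, interval: float = 0.05) -> List[float]:
    """Thresholds from iou_thresh to max_iou_thresh (inclusive) in steps of interval."""
    if max_iou_thresh <= iou_thresh:
        raise ValueError("max_iou_thresh must be larger than iou_thresh")
    span = max_iou_thresh - iou_thresh
    steps = round(span / interval)
    if abs(steps * interval - span) > 1e-9:
        raise ValueError("max_iou_thresh - iou_thresh must be divisible by interval")
    return [round(iou_thresh + i * interval, 10) for i in range(steps + 1)]


print(iou_thresholds(0.5, 0.95))
# [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
```

For example, iou_thresh=0.5 with max_iou_thresh=0.93 would be rejected, because 0.43 is not a multiple of 0.05.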
compute()
Compute the final result from the metric states.
All state variables registered with self.add_state are synchronized
across devices before this method is executed.
gather_metrics()
Update num_inst and sum_metric.
reset()
Clear the internal statistics to initial state.
update(model_outs: Dict)
model_outs is a dict with the following keys:
- pred_bboxes (List) – Each element is the prediction result of an image,
with shape (N, 6), where 6 means (x1, y1, x2, y2, label, score).
- gt_bboxes (List) – Each element holds the box coordinates of an image,
with shape (N, 4), where 4 means (x1, y1, x2, y2).
- gt_classes (List) – Each element holds the box classes of an image,
with shape (N,).
- gt_difficult (List) – Each element holds the difficult flags of the boxes
in an image, with shape (N,).
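For the box layout described above, a prediction row (x1, y1, x2, y2, label, score) is compared against the ground-truth boxes (x1, y1, x2, y2) of the same class and counts as a true positive when the best IoU exceeds iou_thresh (the classic VOC evaluation uses a strict comparison). The sketch uses continuous coordinates; the classic VOC code adds 1 to pixel widths and heights, and whether VOCMApMetric does the same is left open here. Difficult flags and duplicate-match suppression are ignored, and `box_iou`/`is_true_positive` are hypothetical helpers.

```python
from typing import Sequence


def box_iou(a: Sequence[float], b: Sequence[float]) -> float:
    """IoU of two (x1, y1, x2, y2) boxes with continuous coordinates."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def is_true_positive(
    pred_row: Sequence[float],
    gt_bboxes: Sequence[Sequence[float]],
    gt_classes: Sequence[int],
    iou_thresh: float = 0.5,
) -> bool:
    """pred_row = (x1, y1, x2, y2, label, score); only same-class GTs are considered."""
    label = pred_row[4]
    ious = [box_iou(pred_row[:4], gt) for gt, cls in zip(gt_bboxes, gt_classes) if cls == label]
    return bool(ious) and max(ious) > iou_thresh


gt_bboxes = [[0.0, 0.0, 10.0, 10.0], [20.0, 20.0, 30.0, 30.0]]
gt_classes = [0, 1]
print(box_iou([0.0, 0.0, 10.0, 10.0], [5.0, 0.0, 15.0, 10.0]))  # 50 / 150
print(is_true_positive([1.0, 0.0, 11.0, 10.0, 0, 0.9], gt_bboxes, gt_classes))  # True (IoU 90/110)
print(is_true_positive([1.0, 0.0, 11.0, 10.0, 1, 0.9], gt_bboxes, gt_classes))  # False (wrong class)
```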