6.1.7.5. hat.metrics¶
Metrics widely used for different datasets in HAT.
6.1.7.5.1. Metrics¶
Accuracy – Computes accuracy classification score.
AccuracySeg – # TODO(min.du, 0.5): merged with Accuracy #.
TopKAccuracy – Computes top k predictions accuracy.
COCODetectionMetric – Evaluation in COCO protocol.
Kitti2DMetric – Kitti2D detection metric.
LossShow – Show loss.
MeanIOU – Evaluation of segmentation results.
EvalMetric – Base class for all evaluation metrics.
VOC07MApMetric – Mean average precision metric for PASCAL VOC 07 dataset.
VOCMApMetric – Calculate mean AP for object detection task.
EndPointError – Metric for OpticalFlow task, endpoint error (EPE).
6.1.7.5.2. API Reference¶
- class hat.metrics.Accuracy(axis=1, name='accuracy')¶
Computes accuracy classification score.
- Parameters
axis (int) – The axis that represents classes.
name (str) – Name of this metric instance for display.
- update(labels, preds)¶
Override this method to update the state variables.
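As a rough illustration of what this metric computes (not the actual hat implementation, which is built on torchmetrics and operates on tensors), accuracy over per-sample class scores can be sketched as:

```python
def accuracy(labels, preds):
    """Fraction of samples whose argmax class score matches the label.

    labels: list of int class ids; preds: list of per-sample score lists,
    mimicking class scores along dim 1 of a (batch, num_classes) tensor.
    """
    correct = sum(
        1 for label, scores in zip(labels, preds)
        if max(range(len(scores)), key=scores.__getitem__) == label
    )
    return correct / len(labels)
```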
- class hat.metrics.AccuracySeg(name='accuracy', axis=1)¶
# TODO(min.du, 0.5): merged with Accuracy #.
- update(output)¶
Override this method to update the state variables.
- class hat.metrics.COCODetectionMetric(ann_file: str, val_interval: int = 1, name: str = 'COCOMeanAP', save_prefix: str = './WORKSPACE/results', adas_eval_task: Optional[str] = None, use_time: bool = True, cleanup: bool = False)¶
Evaluation in COCO protocol.
- Parameters
ann_file – validation data annotation json file path.
val_interval – evaluation interval.
name – name of this metric instance for display.
save_prefix – path to save result.
adas_eval_task – task name for adas-eval, such as ‘vehicle’, ‘person’ and so on.
use_time – whether to use time for name.
cleanup – whether to clean up the saved results when the process ends.
- Raises
RuntimeError – fail to write json to disk.
- get()¶
Get evaluation metrics.
- reset()¶
Reset the metric state variables to their default value.
If (and only if) there are state variables not registered with self.add_state that need to be regularly reset to their default values, extend this method in subclasses.
- update(output: Dict)¶
Update internal buffer with latest predictions.
Note that the statistics are not available until you call self.get() to return the metrics.
- Parameters
output – A dict of model output which includes det results and image infos.
- class hat.metrics.EndPointError(name='EPE')¶
Metric for OpticalFlow task, endpoint error (EPE).
The endpoint error measures the distance between the endpoints of two optical flow vectors (u0, v0) and (u1, v1) and is defined as sqrt((u0 - u1) ** 2 + (v0 - v1) ** 2).
- Parameters
name – metric name.
- update(labels, preds)¶
Override this method to update the state variables.
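The EPE formula above can be sketched in plain Python (the hat class itself accumulates this over torch tensors); here the mean is taken over all flow vectors:

```python
import math

def end_point_error(flow_gt, flow_pred):
    """Mean endpoint error between two flow fields.

    Each flow field is a list of (u, v) vectors; EPE per vector is
    sqrt((u0 - u1) ** 2 + (v0 - v1) ** 2), averaged over all vectors.
    """
    errors = [
        math.hypot(u0 - u1, v0 - v1)
        for (u0, v0), (u1, v1) in zip(flow_gt, flow_pred)
    ]
    return sum(errors) / len(errors)
```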
- class hat.metrics.EvalMetric(name: Union[List[str], str], process_group: Optional[torch._C._distributed_c10d.ProcessGroup] = None, warn_without_compute: bool = True)¶
Base class for all evaluation metrics.
Built on top of torchmetrics.metric.Metric, this base class introduces the name attribute and a name-value format output (the get method). It also makes it possible to synchronize state tensors of different shapes across devices, to support AP-like metrics.
Note
This is a base class that provides common metric interfaces. One should not use this class directly, but inherit it to create new metric classes instead.
- Parameters
name – Name of this metric instance for display.
process_group – Specify the process group on which synchronization is called. Default: None (which selects the entire world).
warn_without_compute – Whether to log a warning if self.compute is not called in self.get. Since synchronization among devices is executed in self.compute, this value reflects whether the metric supports distributed computation.
- compute() Union[float, List[float]] ¶
Override this method to compute final results from metric states.
All states variables registered with self.add_state are synchronized across devices before the execution of this method.
- get() Tuple[Union[str, List[str]], Union[float, List[float]]] ¶
Get current evaluation result.
To skip the synchronization among devices, please override this method and calculate results without calling self.compute().
- Returns
names – Name of the metrics.
values – Value of the evaluations.
- get_name_value()¶
Return zipped name and value pairs.
- Returns
A (name, value) tuple list.
- Return type
List(tuples)
- reset() None ¶
Reset the metric state variables to their default value.
If (and only if) there are state variables not registered with self.add_state that need to be regularly reset to their default values, extend this method in subclasses.
- abstract update(*_: Any, **__: Any) None ¶
Override this method to update the state variables.
- class hat.metrics.Kitti2DMetric(anno_file: str, name: str = 'kittiAP', is_plot: bool = True)¶
Kitti2D detection metric.
For details, you can refer to http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=2d.
- Parameters
anno_file (str) – validation data annotation json file path.
name – name of this metric instance for display.
is_plot – whether to plot the PR curve.
- get()¶
Get current evaluation result.
To skip the synchronization among devices, please override this method and calculate results without calling self.compute().
- Returns
names – Name of the metrics.
values – Value of the evaluations.
- reset()¶
Reset the metric state variables to their default value.
If (and only if) there are state variables not registered with self.add_state that need to be regularly reset to their default values, extend this method in subclasses.
- update(output: Dict)¶
- Parameters
output – A dict of model output which includes det results and image infos. Supports batch_size >= 1.
output['pred_bboxes'] (List[torch.Tensor]) – Network output for each input.
output['img_name'] (List[str]) – Image name for each input.
- class hat.metrics.Kitti3DMetricDet(compute_aos, current_classes, name='kitti3dAPDet', difficultys=[0, 1, 2])¶
- get()¶
Get current evaluation result.
To skip the synchronization among devices, please override this method and calculate results without calling self.compute().
- Returns
names – Name of the metrics.
values – Value of the evaluations.
- reset()¶
Reset the metric state variables to their default value.
If (and only if) there are state variables not registered with self.add_state that need to be regularly reset to their default values, extend this method in subclasses.
- update(preds, labels)¶
Override this method to update the state variables.
- class hat.metrics.LossShow(name: str = 'loss', norm: bool = True)¶
Show loss.
# TODO(min.du, 0.1): a better class name is required #
- Parameters
name – Name of this metric instance for display.
norm – Whether to normalize the loss when its size is bigger than 1. If True, the mean loss is computed; otherwise the loss sum. Default: True.
- get()¶
Get current evaluation result.
To skip the synchronization among devices, please override this method and calculate results without calling self.compute().
- Returns
names – Name of the metrics.
values – Value of the evaluations.
- reset()¶
Reset the metric state variables to their default value.
If (and only if) there are state variables not registered with self.add_state that need to be regularly reset to their default values, extend this method in subclasses.
- update(loss: Union[torch.Tensor, Dict[str, torch.Tensor]])¶
Override this method to update the state variables.
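The norm flag described above amounts to choosing between mean and sum when more than one loss value arrives at once; a minimal sketch (a hypothetical helper, not hat's implementation, which also accepts a dict of tensors):

```python
def reduce_losses(losses, norm=True):
    """Collapse a list of loss values to one number: mean if norm, else sum."""
    total = sum(losses)
    return total / len(losses) if norm else total
```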
- class hat.metrics.MeanIOU(seg_class: List[str], name: str = 'MeanIOU', ignore_index: int = 255, global_ignore_index: Union[Sequence, int] = 255, verbose: bool = False)¶
Evaluation of segmentation results.
- Parameters
seg_class (list(str)) – A list of classes the segmentation dataset includes; the order should be the same as the labels.
name (str) – Name of this metric instance for display, also used as a monitor param for Checkpoint.
ignore_index (int) – The label index that will be ignored in evaluation.
global_ignore_index (list, int) – The label index that will be ignored in global evaluation, such as mIoU, mAcc, and aAcc. A list of label indices is supported.
verbose (bool) – Whether to return verbose values for aidi eval. Default: False.
- compute()¶
Get evaluation metrics.
- update(label: torch.Tensor, preds: Union[Sequence[torch.Tensor], torch.Tensor])¶
Update internal buffer with latest predictions.
Note that the statistics are not available until you call self.get() to return the metrics.
- Parameters
preds – Model output.
label – Ground truth.
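The per-class IoU behind this metric can be sketched in plain Python over flattened label/prediction ids (the hat class works on torch tensors and also reports mAcc/aAcc); mIoU averages IoU over the classes that appear:

```python
def mean_iou(labels, preds, num_classes, ignore_index=255):
    """mIoU over flattened label/pred class ids.

    Per-class IoU = intersection / union; pixels whose label equals
    ignore_index are skipped, mirroring the ignore_index parameter above.
    """
    inter = [0] * num_classes
    union = [0] * num_classes
    for gt, pr in zip(labels, preds):
        if gt == ignore_index:
            continue
        if gt == pr:
            inter[gt] += 1
            union[gt] += 1
        else:
            union[gt] += 1
            union[pr] += 1
    ious = [i / u for i, u in zip(inter, union) if u > 0]
    return sum(ious) / len(ious)
```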
- class hat.metrics.TopKAccuracy(top_k, name='top_k_accuracy')¶
Computes top k predictions accuracy.
TopKAccuracy differs from Accuracy in that it considers a prediction to be correct as long as the ground truth label is among the top k predicted labels. If top_k = 1, TopKAccuracy is identical to Accuracy.
- Parameters
top_k (int) – Number of top predictions among which the target is searched.
name (str) – Name of this metric instance for display.
- update(labels, preds)¶
Override this method to update the state variables.
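The top-k rule described above can be sketched as follows (a standalone illustration, not hat's tensor-based implementation):

```python
def top_k_accuracy(labels, preds, top_k=1):
    """Fraction of samples whose label is among the top_k highest scores."""
    hits = 0
    for label, scores in zip(labels, preds):
        # class ids ranked by descending score
        ranked = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
        hits += label in ranked[:top_k]
    return hits / len(labels)
```

With top_k = 1 this reduces to plain argmax accuracy, matching the note that TopKAccuracy is then identical to Accuracy.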
- class hat.metrics.VOC07MApMetric(*args, **kwargs)¶
Mean average precision metric for PASCAL VOC 07 dataset.
- Parameters
iou_thresh (float) – IOU overlap threshold for TP.
class_names (List[str]) – If provided, AP is printed out for each class.
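The VOC07 protocol computes AP by 11-point interpolation over recall; a sketch of that final step, given already-computed precision/recall pairs (matching detections to ground truth at iou_thresh is handled by the class itself and is omitted here):

```python
def voc07_ap(recalls, precisions):
    """11-point interpolated AP: at each recall threshold t in
    {0.0, 0.1, ..., 1.0}, take the max precision whose recall >= t,
    then average the 11 values."""
    ap = 0.0
    for i in range(11):
        t = i / 10.0
        p = max(
            (prec for rec, prec in zip(recalls, precisions) if rec >= t),
            default=0.0,
        )
        ap += p / 11.0
    return ap
```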
- class hat.metrics.VOCMApMetric(iou_thresh=0.5, class_names=None)¶
Calculate mean AP for object detection task.
- Parameters
iou_thresh (float) – IOU overlap threshold for TP
class_names (List[str]) – if provided, will print out AP for each class
- get()¶
Get current evaluation result.
To skip the synchronization among devices, please override this method and calculate results without calling self.compute().
- Returns
names – Name of the metrics.
values – Value of the evaluations.
- reset()¶
Clear the internal statistics to initial state.
- update(model_outs)¶
Override this method to update the state variables.