Custom metrics in Keras and how simple they are to use in tensorflow2.2

栏目: IT技术 · 发布时间: 4年前

So lets get down to it. We first make a custom metric class. While there are more steps to this and they are show in the referenced jupyter notebook , the important thing is to implement the API that integrates with the rest of Keras training and testing workflow. That is as simple as implementing and update_state that takes in the true labels and predictions, a reset_states that re-initializes the metric.

class ConfusionMatrixMetric(tf.keras.metrics.Metric):


    def update_state(self, y_true, y_pred,sample_weight=None):
        self.total_cm.assign_add(self.confusion_matrix(y_true,y_pred))
        return self.total_cm

    def result(self):
        return self.process_confusion_matrix()

    def confusion_matrix(self,y_true, y_pred):
        """
        Make a confusion matrix
        """
        y_pred=tf.argmax(y_pred,1)
        cm=tf.math.confusion_matrix(y_true,y_pred,dtype=tf.float32,num_classes=self.num_classes)
        return cm

    def process_confusion_matrix(self):
        "returns precision, recall and f1 along with overall accuracy"
        cm=self.total_cm
        diag_part=tf.linalg.diag_part(cm)
        precision=diag_part/(tf.reduce_sum(cm,0)+tf.constant(1e-15))
        recall=diag_part/(tf.reduce_sum(cm,1)+tf.constant(1e-15))
        f1=2*precision*recall/(precision+recall+tf.constant(1e-15))
        return precision,recall,f1

In the normal Keras workflow, the method result will be called and it will return a number and nothing else needs to be done. However, in our case we have three tensors for precision, recall and f1 being returned and Keras does not know how to handle this out of the box. This is where the new features of tensorflow 2.2 come in.

Request for deletion

About

MC.AI – Aggregated news about artificial intelligence

MC.AI collects interesting articles and news about artificial intelligence and related areas. The contributions come from various open sources and are presented here in a collected form.

The copyrights are held by the original authors, the source is indicated with each contribution.

Contributions which should be deleted from this platform can be reported using the appropriate form (within the contribution).

MC.AI is open for direct submissions, we look forward to your contribution!

Search on MC.AI

mc.ai aggregates articles from different sources - copyright remains at original authors


以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

数据结构与算法分析(C++版)(第3版)

数据结构与算法分析(C++版)(第3版)

Clifford A. Shaffer / 张铭、刘晓丹、等译 / 电子工业出版社 / 2013 / 59.00元

本书采用当前流行的面向对象的C++程序设计语言来描述数据结构和算法, 因为C++语言是程序员最广泛使用的语言。因此, 程序员可以把本书中的许多算法直接应用于将来的实际项目中。尽管数据结构和算法在设计本质上还是很底层的东西, 并不像大型软件工程项目开发那样, 对面向对象方法具有直接的依赖性, 因此有人会认为并不需要采用高层次的面向对象技术来描述底层算法。 但是采用C++语言能更好地体现抽象数据类型的......一起来看看 《数据结构与算法分析(C++版)(第3版)》 这本书的介绍吧!

MD5 加密
MD5 加密

MD5 加密工具

XML 在线格式化
XML 在线格式化

在线 XML 格式化压缩工具

Markdown 在线编辑器
Markdown 在线编辑器

Markdown 在线编辑器