Linuxkiss
    首页 Linux C/C++ C++面试 Qt答疑 Qml中文手册 Qt CMake Python 工具
Linuxkiss
www.linuxkiss.com 你可以精通一门IT技术
  1. 首页
  2. openCv
  3. 正文

Python3 实现DNN功能【详】

2018年04月07日 15点热度 0人点赞

Python3 实现DNN功能,供初学者参考一下。

代码展示:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# construct simple DNN
import numpy as np
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt

import sys


def generate_data():
    x = np.linspace(-2, 2, 100)[np.newaxis, :]
    noise = np.random.normal(0.0, 0.5, size=(1, 100))
    y = x ** 2 + noise
    return x, y


class DNN():
    def __init__(self, input_nodes=1, hidden1_nodes=4, hidden2_nodes=4, output_nodes=1):
        self.input_nodes = input_nodes
        self.hidden1_nodes = hidden1_nodes
        self.hidden2_nodes = hidden2_nodes
        self.output_nodes = output_nodes
        self.build_DNN()

    def build_DNN(self):
        np.random.seed(1)
        # Layer1 parameter
        self.w1 = np.random.normal(0.0, 0.1, size=(self.hidden1_nodes, self.input_nodes))
        self.b1 = np.zeros(shape=(self.hidden1_nodes, 1))
        # Layer2 parameter
        self.w2 = np.random.normal(0.0, 0.2, size=(self.hidden2_nodes, self.hidden1_nodes))
        self.b2 = np.ones(shape=(self.hidden2_nodes, 1))
        # Layer3 parameter
        self.w3 = np.random.normal(0.0, 0.5, size=(self.output_nodes, self.hidden2_nodes))
        self.b3 = np.zeros(shape=(self.output_nodes, 1))

    def forwardPropagation(self, inputs):
        self.z1 = np.matmul(self.w1, inputs) + self.b1
        self.a1 = 1 / (1 + np.exp(-self.z1))
        self.z2 = np.matmul(self.w2, self.a1) + self.b2
        self.a2 = 1 / (1 + np.exp(-self.z2))
        self.z3 = np.matmul(self.w3, self.a2) + self.b3
        self.a3 = self.z3

    def backwardPropagation(self, da, a, a_1, w, b, last=False):
        '''
        da:current layer activation output partial devirate result
        a:current layer activation output
        a_1:previous layer of current layer activation output
        w:current parameter
        b:current bias
        '''
        # dz = da/dz
        if last:
            dz = da
        else:
            dz = a * (1 - a) * da
        # dw = dz/dw
        nums = da.shape[1]
        dw = np.matmul(dz, a_1.T) / nums
        db = np.mean(dz, axis=1, keepdims=True)
        # da_1 = dz/da_1
        da_1 = np.matmul(w.T, dz)

        w -= 0.5 * dw
        b -= 0.5 * db
        return da_1

    def train(self, x, y, max_iter=50000):
        for i in range(max_iter):
            self.forwardPropagation(x)
            # print(self.a3)
            loss = 0.5 * np.mean((self.a3 - y) ** 2)
            da = self.a3 - y
            da_2 = self.backwardPropagation(da, self.a3, self.a2, self.w3, self.b3, True)
            da_1 = self.backwardPropagation(da_2, self.a2, self.a1, self.w2, self.b2)
            da_0 = self.backwardPropagation(da_1, self.a1, x, self.w1, self.b1)
            self.view_bar(i + 1, max_iter, loss)
        return self.a3

    def view_bar(self, step, total, loss):
        rate = step / total
        rate_num = int(rate * 40)
        r = '\rstep-%d loss value-%.4f[%s%s]\t%d%% %d/%d' % \
            (step, loss, '>' * rate_num, '-' * (40 - rate_num),int(rate * 100), step, total)
        
        sys.stdout.write(r)
        sys.stdout.flush()


if __name__ == '__main__':
    x, y = generate_data()
    plt.scatter(x, y, c='r')
    plt.ion()

    dnn = DNN()
    predict = dnn.train(x, y)
    print('plot')
    plt.plot(x.flatten(), predict.flatten(), '-')
    plt.show()

执行结果:

标签: openCv
最后更新:2020年06月17日

Leo

保持饥渴的专注,追求最佳的品质

点赞
< 上一篇
下一篇 >
关注公众号

日历
2021年4月
一 二 三 四 五 六 日
« 2月    
 1234
567891011
12131415161718
19202122232425
2627282930  
最新 热点 随机
最新 热点 随机
windows中出现"无法解析的外部符号"到底是什么原因 Qt5中lambda表达式用法,非常实用 warning: class 'InterFace' defines a non-default destructor but does not define a copy constructor, a copy assignment operator, a move constructor or a move assignment operator 无法解析的外部符号 "public: static struct QMetaObject const Windows下Qt代码出现的错误总结 QT Creator如何在创建项目的时候,头文件和cpp文件的首字母默认大写
7.1.1 用户配置文件–用户信息文件 qmake语言描述--变量(Variables) “野指针”是什么以及它的作用 4.1.3 文件处理命令: 1touch 2cat 3tac 4more 5less 6head 7tail QML对象类型(QML Object Types) strcpy库函数的实现细节
标签聚合
C++ C/C++面试 openCv Linux qmake qml中文文档 Qt qml中文手册

COPYRIGHT © 2020 Linuxkiss. ALL RIGHTS RESERVED.