1 module des.isys.neiro.layer.neiron;
2 
3 import std.math;
4 import std.algorithm;
5 import std.traits;
6 import std.conv;
7 
8 import des.isys.neiro.neiron;
9 import des.isys.neiro.func;
10 
/++
 Weighted link that participates in error backpropagation.

 Both the source value and the weight are of the floating-point type `T`.
 Concrete subclasses (for example `BaseBPLink`) implement the weight
 bookkeeping declared here.
 +/
abstract class BPLink(T) : WeightLink!(T,T)
    if( isFloatingPoint!T )
{
    /// Weight change kept by the implementation (see subclasses).
    @property T deltaWeight() const;

    /// Sends the error term `beta` back towards the link's input side.
    void propagateError( T beta );

    /// Adjusts the weight using the two correction coefficients.
    void correct( T k1, T k2 );
}
19 
/++
 Neuron interface extended with the operations needed to train it by
 backpropagation.
 +/
interface BPNeiron(T) : Neiron!T
    if( isFloatingPoint!T )
{
    /// Hands an error contribution to this neuron.
    void addError( T err );

    /// Performs one training step with the given coefficients.
    void correct( T nu, T alpha );

    /// Replaces the incoming links of this neuron.
    void setLinks( BPLink!T[] lnks );

    /// The incoming backpropagation links.
    @property BPLink!T[] bpLinks();
}
28 
/++
 Base for neurons that take no part in training.

 Every `BPNeiron` operation is a no-op and `bpLinks` is always empty, so
 error sent to such a neuron is discarded. Subclasses such as
 `ReferenceBPNeiron` supply `output`.
 +/
abstract class FakeBPNeiron(T) : BPNeiron!T
    if( isFloatingPoint!T )
{
    /// Nothing to compute.
    void process() {}

    /// Incoming error is ignored.
    void addError( T ) {}

    /// No weights to correct.
    void correct( T, T ) {}

    /// Links are not stored.
    void setLinks( BPLink!T[] ) {}

    /// Always an empty array.
    @property BPLink!T[] bpLinks()
    {
        return [];
    }
}
38 
/++
 Adapter exposing an ordinary `Neiron` as a `BPNeiron`.

 `output` is forwarded to the wrapped neuron. All training operations are
 the no-ops inherited from `FakeBPNeiron`, so no error reaches the wrapped
 neuron.
 +/
class ReferenceBPNeiron(T) : FakeBPNeiron!T
    if( isFloatingPoint!T )
{
    /// Wrapped neuron (the constructor's contract rejects null).
    Neiron!T neiron;

    ///
    this( Neiron!T neiron )
    in
    {
        assert( neiron !is null );
    }
    body
    {
        this.neiron = neiron;
    }

    /// Output of the wrapped neuron.
    @property T output() const
    {
        return neiron.output;
    }
}
48 
/++
 Standard backpropagation link with a momentum term.

 Reads its source value from a `BPNeiron`, keeps the current weight (`lw`)
 and the most recent weight change (`dw`), and passes weighted error back
 to the input neuron through `addError`.
 +/
class BaseBPLink(T) : BPLink!(T)
    if( isFloatingPoint!T )
{
protected:
    /// Neuron at the input end of the link.
    BPNeiron!T neiron;

    override @property T source() const { return neiron.output; }

    /// Current weight (`lw`) and last applied weight change (`dw`).
    T lw, dw;

public:

    /++
     Params:
        input = neuron feeding this link; must not be null (checked by the
                contract, matching the other constructors in this module)
        w     = initial weight
     +/
    this( BPNeiron!T input, T w=1 )
    in { assert( input !is null ); } body
    {
        neiron = input;
        lw = w;
        dw = 0;
    }

    override @property
    {
        T weight() const { return lw; }

        /// Sets the weight and clears the stored weight change.
        void weight( T nlw )
        {
            lw = nlw;
            dw = 0;
        }

        T deltaWeight() const { return dw; }
    }

    /// Passes `beta * weight` to the input neuron via `addError`.
    override void propagateError( T beta )
    { neiron.addError( beta * weight ); }

    /++
     Momentum update: `dw = dw * k1 + k2 * source`, then `lw += dw`.

     `BaseBPNeiron` passes k1 = alpha and k2 = (1-alpha) * nu * beta.
     +/
    override void correct( T k1, T k2 )
    {
        dw = dw * k1 + k2 * source;
        lw += dw;
    }
}
91 
/++
 Standard backpropagation neuron.

 The activation argument is multiplied by `link_scale` (1 / number of
 links). Error contributions received through `addError` are summed in
 `error`. `correct` runs one update step over all incoming links and then
 clears the accumulated error.
 +/
class BaseBPNeiron(T) : BaseNeiron!T, BPNeiron!T
    if( isFloatingPoint!T )
{
protected:

    /// Sum of error contributions received since the last `correct`.
    T error = 0;

    /// Factor applied to the activation argument: 1 / links count, or 1 without links.
    T link_scale = 1;

    /// Activation function; its derivative is obtained via `dx`.
    DerivativeFunction!T func;
    BPLink!T[] bp_links;

    override T activate( T x ) { return func( x * link_scale ); }

    /// Incoming links as plain `Link`s, converted with `std.conv.to`.
    override @property Link!T[] links()
    { return to!(Link!T[])(bp_links); }

public:

    /++
     Params:
        func = activation function; must not be null
     +/
    this( DerivativeFunction!T func )
    in { assert( func !is null ); } body
    {
        super(0);
        this.func = func;
    }

    /++
     Replaces the incoming links and recomputes `link_scale`.

     With an empty array, `link_scale` is set to 1. Dividing by a zero
     length would make it infinite and corrupt every later activation.
     +/
    void setLinks( BPLink!T[] lnks )
    {
        bp_links = lnks;
        if( bp_links.length )
            link_scale = 1.0 / cast(T)bp_links.length;
        else
            link_scale = 1;
    }

    /// Adds `err` to the accumulated error.
    void addError( T err ) { error += err; }

    /++
     One training step. Runs backpropagation, then the stagnation hook, and
     finally clears the accumulated error.

     Params:
        nu    = step-size factor in k2 = (1 - alpha) * nu * beta
        alpha = weight of the previous weight change (passed to links as k1)
     +/
    void correct( T nu, T alpha )
    {
        backpropagation( nu, alpha );
        eliminateStagnation();
        resetError();
    }

    @property BPLink!T[] bpLinks() { return bp_links; }

protected:
    /++
     Computes beta = func.dx(value) * error. For each link it then calls
     `propagateError(beta * link_scale)`, followed by
     `correct(alpha, (1 - alpha) * nu * beta)`.

     `propagateError` is called before `correct` on every link, so a
     `BaseBPLink` propagates with its pre-update weight.
     +/
    void backpropagation( T nu, T alpha )
    {
        // NOTE(review): activate() evaluates func at x * link_scale, but the
        // derivative here is taken at `value`, and beta is not scaled by
        // link_scale for the weight update. Confirm against BaseNeiron how
        // `value` relates to activate()'s argument.
        T beta = func.dx( value ) * error;
        T k2 = ( 1.0 - alpha ) * nu * beta;

        foreach( link; bp_links )
        {
            link.propagateError( beta * link_scale );
            link.correct( alpha, k2 );
        }
    }

    /// Hook for subclasses; does nothing here.
    void eliminateStagnation() { }

    void resetError() { error = 0; }
}
152 
unittest
{
    // Input neuron (presumably outputting 1) connected through weight 0.25.
    auto input = new ValueNeiron!float(1);
    auto link = new BaseBPLink!float( new ReferenceBPNeiron!float(input), 0.25 );

    // A single link, so link_scale is 1.
    auto neiron = new BaseBPNeiron!float( new LinearDependence!float(2) );
    neiron.setLinks( [ cast(BPLink!float)link ] );
    neiron.process();

    // Presumably LinearDependence!float(2) is x -> 2x and process() sums
    // weight * source: 2 * (0.25 * 1) = 0.5, exactly representable in float.
    assert( neiron.output == 0.5 );
}