55
66class ConvLayer (NeuralLayer ):
77
8- def __init__ (self , input_size , k , f = 3 , s = 1 , p = 1 , u_type = 'adam' , a_type = 'relu' ):
8+ def __init__ (self , input_size , k , f = 3 , s = 1 , p = 1 , u_type = 'adam' , a_type = 'relu' , dropout = 1 ):
99 self .image_size = 0
1010 self .w = input_size [2 ]
1111 self .h = input_size [1 ]
@@ -20,7 +20,7 @@ def __init__(self, input_size, k, f=3, s=1, p=1, u_type='adam', a_type='relu'):
2020 self .h2 = int ((self .h - self .f + 2 * self .p ) / self .s + 1 )
2121 self .d2 = k
2222
23- super (ConvLayer , self ).__init__ (f * f * self .d , k , u_type = u_type , a_type = a_type )
23+ super (ConvLayer , self ).__init__ (f * f * self .d , k , u_type = u_type , a_type = a_type , dropout = dropout )
2424
2525 def predict (self , batch ):
2626 self .image_size = batch .shape [0 ]
@@ -55,7 +55,8 @@ def forward(self, batch):
5555 l2 += n .regularization ()
5656
5757 sum_weights = np .array (sum_weights )
58- strength = (sum_weights .dot (cols ) + np .array (bias ).reshape (sum_weights .shape [0 ], 1 )).reshape (self .k , self .h2 , self .w2 , - 1 ).transpose (3 , 0 , 1 , 2 )
58+ strength = (sum_weights .dot (cols ) + np .array (bias ).reshape (sum_weights .shape [0 ], 1 ))
59+ strength = strength .reshape (self .k , self .h2 , self .w2 , - 1 ).transpose (3 , 0 , 1 , 2 )
5960
6061 if self .activation :
6162 if self .a_type == 'sigmoid' :
@@ -71,23 +72,32 @@ def backward(self, d, need_d=True):
7172 if d .ndim < 4 :
7273 d = d .reshape (self .w2 , self .h2 , self .k , - 1 ).T
7374
74- delta = d * u .relu_d (self .forward_result )
75- padding = ((self .w - 1 ) * self .s + self .f - self .w2 ) // 2
76- cols = u .im2col_indices (delta , self .f , self .f , padding = padding , stride = self .s )
75+ if self .activation :
76+ if self .a_type == 'sigmoid' :
77+ delta = d * u .sigmoid_d (self .forward_result )
78+ else :
79+ delta = d * u .relu_d (self .forward_result )
80+ else :
81+ delta = d
82+
7783 sum_weights = []
7884 for index , n in enumerate (self .neurons ):
79- n .delta = delta [:, index , :, :].flatten ()
85+ n .delta = delta [:, index , :, :].transpose ( 1 , 2 , 0 ). flatten ()
8086 if need_d :
81- rot = np .rot90 (n .weights .reshape (self .d , self .f * self .f ), 2 ). reshape ( self . d , self . f , self . f )[:: - 1 ]
87+ rot = np .rot90 (n .weights .reshape (self .d , self .f , self .f ), k = 2 , axes = ( 1 , 2 ))
8288 sum_weights .append (rot )
8389
8490 if not need_d :
8591 return
8692
87- sum_weights = np .array (sum_weights ).transpose (1 ,0 ,2 ,3 ).reshape (self .d , - 1 )
93+ padding = ((self .w - 1 ) * self .s + self .f - self .w2 ) // 2
94+ cols = u .im2col_indices (delta , self .f , self .f , padding = padding , stride = self .s )
95+
96+ sum_weights = np .array (sum_weights ).transpose (1 , 0 , 2 , 3 ).reshape (self .d , - 1 )
8897
8998 result = sum_weights .dot (cols )
9099 im = result .reshape (self .d , self .h , self .w , - 1 ).transpose (3 , 0 , 1 , 2 )
100+
91101 return im
92102
93103 def output_size (self ):
0 commit comments