forked from sony791210/VerificationCode
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_captcha.py
More file actions
41 lines (39 loc) · 1.97 KB
/
train_captcha.py
File metadata and controls
41 lines (39 loc) · 1.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#!/usr/bin/python
import tensorflow as tf
import numpy as np
import string
import captcha_model
def main(max_steps=None, target_accuracy=0.99,
         checkpoint_path="./capcha_model.ckpt"):
    """Train the captcha CNN until per-character accuracy exceeds a threshold.

    Builds a TF1 graph around ``captcha_model.captchaModel.create_model``.
    Training uses Adam (learning rate 0.001) on freshly generated batches of
    32 images. Every 100 steps it measures per-character accuracy on a fresh
    batch of 256. Once that accuracy exceeds ``target_accuracy``, it saves the
    variables to ``checkpoint_path`` and stops.

    Args:
        max_steps: Optional cap on the number of training steps. ``None``
            (the default) keeps the original behaviour of training until the
            threshold is reached, which may never happen.
        target_accuracy: Accuracy that must be strictly exceeded before the
            checkpoint is written and training stops.
        checkpoint_path: Path passed to ``tf.train.Saver.save``. The default
            keeps the original (misspelled) filename so existing loaders of
            this checkpoint still find it.

    Returns:
        The last measured accuracy, or ``None`` if no evaluation ran.
    """
    captcha = captcha_model.captchaModel()
    # The character set (4th value) is not needed for training.
    width, height, char_num, _characters, classes = captcha.get_parameter()

    # Images are fed as [batch, height, width, 1]; labels as a flat vector of
    # char_num * classes values per sample.
    x = tf.placeholder(tf.float32, [None, height, width, 1])
    y_ = tf.placeholder(tf.float32, [None, char_num * classes])
    keep_prob = tf.placeholder(tf.float32)

    model = captcha_model.captchaModel(width, height, char_num, classes)
    y_conv = model.create_model(x, keep_prob)

    cross_entropy = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=y_conv))
    train_step = tf.train.AdamOptimizer(learning_rate=0.001).minimize(cross_entropy)

    # Accuracy is per character: reshape to [batch, char_num, classes] and
    # compare the argmax class chosen at each character position.
    predict = tf.reshape(y_conv, [-1, char_num, classes])
    real = tf.reshape(y_, [-1, char_num, classes])
    correct_prediction = tf.cast(
        tf.equal(tf.argmax(predict, 2), tf.argmax(real, 2)), tf.float32)
    accuracy = tf.reduce_mean(correct_prediction)

    saver = tf.train.Saver()
    config = tf.ConfigProto()
    # Allocate GPU memory on demand instead of reserving it all up front.
    config.gpu_options.allow_growth = True

    acc = None
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        step = 0
        while max_steps is None or step < max_steps:
            # NOTE(review): a new generator is created for every batch and only
            # its first item is consumed. If gen_captcha is an infinite
            # generator, it could be created once and reused; confirm in
            # captcha_model before changing.
            batch_x, batch_y = next(captcha.gen_captcha(32))
            sess.run(train_step,
                     feed_dict={x: batch_x, y_: batch_y, keep_prob: 0.75})
            if step % 100 == 0:
                # Evaluate with dropout disabled (keep_prob = 1).
                batch_x_test, batch_y_test = next(captcha.gen_captcha(256))
                acc = sess.run(accuracy, feed_dict={x: batch_x_test,
                                                    y_: batch_y_test,
                                                    keep_prob: 1.})
                print('###############################################step:%d,accuracy:%f' % (step, acc))
                if acc > target_accuracy:
                    saver.save(sess, checkpoint_path)
                    break
            step += 1
    return acc


if __name__ == '__main__':
    main()