python - 使用TensorFlow創(chuàng)建邏輯回歸模型訓(xùn)練結(jié)果為nan
問(wèn)題描述
在TensorFlow中,我想創(chuàng)建一個(gè)邏輯回歸模型,代價(jià)函數(shù)如下:
使用的數(shù)據(jù)集截圖如下:
我的代碼如下:
train_X = train_data[:, :-1]train_y = train_data[:, -1:]feature_num = len(train_X[0])sample_num = len(train_X)print('Size of train_X: {}x{}'.format(sample_num, feature_num))print('Size of train_y: {}x{}'.format(len(train_y), len(train_y[0])))X = tf.placeholder(tf.float32)y = tf.placeholder(tf.float32)W = tf.Variable(tf.zeros([feature_num, 1]))b = tf.Variable([-.3])db = tf.matmul(X, tf.reshape(W, [-1, 1])) + bhyp = tf.sigmoid(db)cost0 = y * tf.log(hyp)cost1 = (1 - y) * tf.log(1 - hyp)cost = (cost0 + cost1) / -sample_numloss = tf.reduce_sum(cost)optimizer = tf.train.GradientDescentOptimizer(0.1)train = optimizer.minimize(loss)init = tf.global_variables_initializer()sess = tf.Session()sess.run(init)print(0, sess.run(W).flatten(), sess.run(b).flatten())sess.run(train, {X: train_X, y: train_y})print(1, sess.run(W).flatten(), sess.run(b).flatten())sess.run(train, {X: train_X, y: train_y})print(2, sess.run(W).flatten(), sess.run(b).flatten())
運(yùn)行結(jié)果截圖如下:
可以看到,在迭代兩次之后,得到的W和b都變成了nan,請(qǐng)問(wèn)是哪里的問(wèn)題?
問(wèn)題解答
回答1:經(jīng)過(guò)一番搜索,找到了問(wèn)題所在。
在選取迭代方式的那一句:
optimizer = tf.train.GradientDescentOptimizer(0.1)
這里0.1的學(xué)習(xí)率過(guò)大,導(dǎo)致不知什么原因在損失函數(shù)中出現(xiàn)了log(0)的情況,結(jié)果導(dǎo)致了損失函數(shù)的值為nan,解決方法是減小學(xué)習(xí)率,比如降到1e-5或者1e-6就可以正常訓(xùn)練了,我根據(jù)自己的情況把學(xué)習(xí)率調(diào)整為了1e-3,程序完美運(yùn)行。
附上最終擬合結(jié)果:
相關(guān)文章:
1. php laravel框架模型作用域2. css - input間的間距和文字上下居中3. javascript - Angular2中,組件傳參問(wèn)題4. node.js - 關(guān)于mongoose方法的回調(diào)函數(shù)的參數(shù)問(wèn)題,如何知道參數(shù)個(gè)數(shù)以及參數(shù)代表什么含義呢?5. css3 - css背景圖片高度百分百,寬度保持比例怎么做?6. android - 百度地圖加載大量marker點(diǎn)有沒(méi)有比較好的解決方案7. html表單如何讓他真正的提交出去8. svg動(dòng)畫和css3動(dòng)畫優(yōu)劣?9. 雙引號(hào)里面的值可以變量嗎10. html - 刷新網(wǎng)頁(yè)后重寫url,去掉錨點(diǎn)鏈接。
