Я пишу алгоритм обучения персептрона на смоделированных данных. Однако программа работает в бесконечном цикле, и вес имеет тенденцию быть очень большим. Что делать, чтобы отлаживать мою программу? Если вы можете указать, что пойдет не так, это также будет оценено.
То, что я здесь делаю, сначала генерирует некоторые данные в случайном порядке и присваивает им метку в соответствии с линейной целевой функцией. Затем используйте изучение персептрона, чтобы изучить эту линейную функцию. Ниже приведены помеченные данные, если я использую 100 выборок.
Кроме того, это упражнение 1.4 на книге Learning from Data.
import numpy as np
a = 1
b = 1
def target(x):
if x[1]>a*x[0]+b:
return 1
else:
return -1
def gen_y(X_sim):
return np.array([target(x) for x in X_sim])
def pcp(X,y):
w = np.zeros(2)
Z = np.hstack((X,np.array([y]).T))
while ~all(z[2]*np.dot(w,z[:2])>0 for z in Z): # some training sample is missclassified
i = np.where(y*np.dot(w,x)<0 for x in X)[0][0] # update the weight based on misclassified sample
print(i)
w = w + y[i]*X[i]
return w
if __name__ == '__main__':
X = np.random.multivariate_normal([1,1],np.diag([1,1]),20)
y = gen_y(X)
w = pcp(X,y)
print(w)
w
я намерен до бесконечности.
[-1.66580705 1.86672845]
[-3.3316141 3.73345691]
[-4.99742115 5.60018536]
[-6.6632282 7.46691382]
[-8.32903525 9.33364227]
[ -9.99484231 11.20037073]
[-11.66064936 13.06709918]
[-13.32645641 14.93382763]
[-14.99226346 16.80055609]
[-16.65807051 18.66728454]
[-18.32387756 20.534013 ]
[-19.98968461 22.40074145]
[-21.65549166 24.26746991]
[-23.32129871 26.13419836]
[-24.98710576 28.00092682]
[-26.65291282 29.86765527]
[-28.31871987 31.73438372]
[-29.98452692 33.60111218]
[-31.65033397 35.46784063]
[-33.31614102 37.33456909]
[-34.98194807 39.20129754]
[-36.64775512 41.068026 ]
В учебнике говорится:
Вопрос здесь:
Помимо вопроса: я действительно не понимаю, почему это правило обновления будет работать. Есть ли хорошая геометрическая интуиция, как это работает? Понятно, что книга ничего не дала. Правило обновления просто w(t+1)=w(t)+y(t)x(t)
где x,y
не классифицируется, т.е. y!=sign(w^T*x)
Следуя одному из ответов,
import numpy as np
np.random.seed(0)
a = 1
b = 1
def target(x):
if x[1]>a*x[0]+b:
return 1
else:
return -1
def gen_y(X_sim):
return np.array([target(x) for x in X_sim])
def pcp(X,y):
w = np.ones(3)
Z = np.hstack((np.array([np.ones(len(X))]).T,X,np.array([y]).T))
while not all(z[3]*np.dot(w,z[:3])>0 for z in Z): # some training sample is missclassified
print([z[3]*np.dot(w,z[:3])>0 for z in Z])
print(not all(z[3]*np.dot(w,z[:3])>0 for z in Z))
i = np.where(z[3]*np.dot(w,z[:3])<0 for z in Z)[0][0] # update the weight based on misclassified sample
w = w + Z[i,3]*Z[i,:3]
print([z[3]*np.dot(w,z[:3])>0 for z in Z])
print(not all(z[3]*np.dot(w,z[:3])>0 for z in Z))
print(i,w)
return w
if __name__ == '__main__':
X = np.random.multivariate_normal([1,1],np.diag([1,1]),20)
y = gen_y(X)
# import matplotlib.pyplot as plt
# plt.scatter(X[:,0],X[:,1],c=y)
# plt.scatter(X[1,0],X[1,1],c='red')
# plt.show()
w = pcp(X,y)
print(w)
Это все еще не работает и печатает
[False, True, False, False, False, True, False, False, False, False, True, False, False, False, False, False, False, False, False, False]
True
[True, False, True, True, True, False, True, True, True, True, True, True, True, True, True, True, False, True, True, True]
True
0 [ 0. -1.76405235 -0.40015721]
[True, False, True, True, True, False, True, True, True, True, True, True, True, True, True, True, False, True, True, True]
True
[True, False, True, True, True, False, True, True, True, True, True, True, True, True, True, True, False, True, True, True]
True
0 [-1. -4.52810469 -1.80031442]
[True, False, True, True, True, False, True, True, True, True, True, True, True, True, True, True, False, True, True, True]
True
[True, False, True, True, True, False, True, True, True, True, True, True, True, True, True, True, False, True, True, True]
True
0 [-2. -7.29215704 -3.20047163]
[True, False, True, True, True, False, True, True, True, True, True, True, True, True, True, True, False, True, True, True]
True
[True, False, True, True, True, False, True, True, True, True, True, True, True, True, True, True, False, True, True, True]
True
0 [ -3. -10.05620938 -4.60062883]
[True, False, True, True, True, False, True, True, True, True, True, True, True, True, True, True, False, True, True, True]
True
Кажется, что 1. только три +1
являются ложными, это указано ниже в графике. 2. Индекс возвращаемого помещения аналогично Matlab find
является неправильным.