У меня есть скрипт, который случайным образом генерирует набор данных и обучает несколько классификаторов для сравнения их друг с другом (он очень похож на http://scikit-learn.org/stable/auto_examples/classification/plot_classifier_comparison.html):
from itertools import product
import numpy as np
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
from sklearn.naive_bayes import GaussianNB, MultinomialNB
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
from sklearn.datasets import make_classification
from sklearn.preprocessing import StandardScaler
from sklearn.cross_validation import train_test_split
# Human-readable labels for the classifiers, listed in the same order as the
# estimator lists constructed inside griddy_mcsearchface().
names = [
    "Linear SVM",
    "Decision Tree",
    "Random Forest",
    "AdaBoost",
    "Naive Bayes",
    "Linear Discriminant Analysis",
    "Quadratic Discriminant Analysis",
]
def griddy_mcsearchface(num_samples, num_feats, num_feats_to_remove, test_size=0.2):
    """Compare each classifier's accuracy with and without trailing features.

    Generates a synthetic classification dataset, then for every classifier
    in ``names`` trains one fresh copy on the full feature matrix and an
    identical fresh copy on the matrix with the last ``num_feats_to_remove``
    columns dropped.

    Parameters
    ----------
    num_samples : int
        Number of samples passed to ``make_classification``.
    num_feats : int
        Number of features in the generated dataset.
    num_feats_to_remove : int
        How many trailing feature columns to drop for the second run.
        Must be < ``num_feats`` or the reduced matrix is empty.
    test_size : float, optional
        Fraction of the data held out for scoring (default 0.2, matching
        the original hard-coded split).

    Yields
    ------
    tuple
        ``(num_samples, num_feats, num_feats_to_remove, name, score, score2)``
        where ``score`` is accuracy on all features and ``score2`` on the
        reduced feature set.
    """
    def _fresh_classifiers():
        # Build brand-new, unfitted estimators each call so the full-feature
        # and reduced-feature runs never share fitted state. (Previously the
        # same list literal was duplicated by hand — DRY fix.)
        return [
            SVC(kernel="linear", C=0.025),
            DecisionTreeClassifier(max_depth=5),
            RandomForestClassifier(max_depth=5, n_estimators=10, max_features=1),
            AdaBoostClassifier(),
            GaussianNB(),
            LinearDiscriminantAnalysis(),
            QuadraticDiscriminantAnalysis(),
        ]

    X, y = make_classification(n_samples=num_samples, n_features=num_feats,
                               n_redundant=0, n_informative=2,
                               random_state=1, n_clusters_per_class=1)
    X = StandardScaler().fit_transform(X)
    # NOTE(review): no random_state on the split, so scores vary between
    # runs — this matches the original behavior. Also, train_test_split is
    # imported from sklearn.cross_validation, which is removed in modern
    # scikit-learn; the import should move to sklearn.model_selection.
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size)
    for name, clf, clf2 in zip(names, _fresh_classifiers(), _fresh_classifiers()):
        clf.fit(X_train, y_train)
        score = clf.score(X_test, y_test)
        # Drop the trailing num_feats_to_remove columns (not a fixed 40% —
        # the original comment was misleading) and retrain from scratch.
        clf2.fit(X_train[:, :-num_feats_to_remove], y_train)
        score2 = clf2.score(X_test[:, :-num_feats_to_remove], y_test)
        yield (num_samples, num_feats, num_feats_to_remove, name, score, score2)
И для запуска:
# Sweep every (samples, features, features-to-drop) combination and print one
# result tuple per classifier for each valid combination.
_samples = [100, 200, 500, 1000, 2000, 5000, 10000, 20000, 50000, 100000]
_feats = [10, 20, 50, 100, 200, 500, 10000]
_feats_to_rm = [5, 10, 25, 50, 100, 250]
for num_samples, num_feats, num_feats_to_remove in product(_samples, _feats, _feats_to_rm):
    # Skip combinations that would drop every feature (or more than exist).
    if num_feats_to_remove >= num_feats:
        continue
    for row in griddy_mcsearchface(num_samples, num_feats, num_feats_to_remove):
        print(row)
Скрипт выводит что-то вроде:
(100, 10, 5, 'Linear SVM', 1.0, 0.40000000000000002)
(100, 10, 5, 'Decision Tree', 1.0, 0.65000000000000002)
(100, 10, 5, 'Random Forest', 1.0, 0.90000000000000002)
(100, 10, 5, 'AdaBoost', 1.0, 0.65000000000000002)
(100, 10, 5, 'Naive Bayes', 1.0, 0.75)
(100, 10, 5, 'Linear Discriminant Analysis', 1.0, 0.40000000000000002)
(100, 10, 5, 'Quadratic Discriminant Analysis', 1.0, 0.84999999999999998)
(100, 20, 5, 'Linear SVM', 1.0, 1.0)
(100, 20, 5, 'Decision Tree', 0.94999999999999996, 0.94999999999999996)
(100, 20, 5, 'Random Forest', 0.80000000000000004, 0.75)
(100, 20, 5, 'AdaBoost', 1.0, 0.94999999999999996)
(100, 20, 5, 'Naive Bayes', 1.0, 1.0)
(100, 20, 5, 'Linear Discriminant Analysis', 1.0, 1.0)
(100, 20, 5, 'Quadratic Discriminant Analysis', 0.84999999999999998, 0.94999999999999996)
(100, 20, 10, 'Linear SVM', 0.94999999999999996, 0.65000000000000002)
(100, 20, 10, 'Decision Tree', 0.94999999999999996, 0.59999999999999998)
(100, 20, 10, 'Random Forest', 0.75, 0.69999999999999996)
(100, 20, 10, 'AdaBoost', 0.94999999999999996, 0.69999999999999996)
(100, 20, 10, 'Naive Bayes', 0.94999999999999996, 0.75)
но clf.fit() обучает классификаторы последовательно, в одном потоке.
Предполагая, что у меня достаточно потоков для запуска всех классификаторов на каждой итерации, как я мог бы обучать классификаторы в разных потоках для каждой итерации for num_samples, num_feats, num_feats_to_remove in product(_samples, _feats, _feats_to_rm)?
И если я ограничен 4 или 8 потоками, но мне нужно обучить > 4 или > 8 классификаторов для каждой итерации, как это сделать?