
SVM TrainAuto test_samples

Hi,

I was looking into updating to 3.0 and checked the new TrainData class and how it works with SVM::trainAuto. As far as I can see, during cross validation the **temp_test_samples** matrix is never filled with data, or am I getting something wrong? It is declared and then passed straight to predict:

```cpp
Mat temp_test_samples(test_sample_count, var_count, CV_32F);
// ...
predict(temp_test_samples, temp_test_responses, 0);
```

Here is the relevant part of the trainAuto source:

```cpp
int test_sample_count = (sample_count + k_fold/2)/k_fold;
int train_sample_count = sample_count - test_sample_count;

SvmParams best_params = params;
double min_error = FLT_MAX;

int rtype = responses.type();

Mat temp_train_samples(train_sample_count, var_count, CV_32F);
Mat temp_test_samples(test_sample_count, var_count, CV_32F);
Mat temp_train_responses(train_sample_count, 1, rtype);
Mat temp_test_responses;

// If grid.minVal == grid.maxVal, this will allow one and only one pass through the loop with params.var = grid.minVal.
#define FOR_IN_GRID(var, grid) \
    for( params.var = grid.minVal; params.var == grid.minVal || params.var < grid.maxVal; params.var = (grid.minVal == grid.maxVal) ? grid.maxVal + 1 : params.var * grid.logStep )

FOR_IN_GRID(C, C_grid)
FOR_IN_GRID(gamma, gamma_grid)
FOR_IN_GRID(p, p_grid)
FOR_IN_GRID(nu, nu_grid)
FOR_IN_GRID(coef0, coef_grid)
FOR_IN_GRID(degree, degree_grid)
{
    // make sure we updated the kernel and other parameters
    setParams(params);

    double error = 0;
    for( k = 0; k < k_fold; k++ )
    {
        int start = (k*sample_count + k_fold/2)/k_fold;
        for( i = 0; i < train_sample_count; i++ )
        {
            j = sidx[(i+start)%sample_count];
            memcpy(temp_train_samples.ptr(i), samples.ptr(j), sample_size);
            if( is_classification )
                temp_train_responses.at<int>(i) = responses.at<int>(j);
            else if( !responses.empty() )
                temp_train_responses.at<float>(i) = responses.at<float>(j);
        }

        // Train SVM on samples
        if( !do_train( temp_train_samples, temp_train_responses ))
            continue;

        for( i = 0; i < train_sample_count; i++ )
        {
            j = sidx[(i+start+train_sample_count) % sample_count];
            memcpy(temp_train_samples.ptr(i), samples.ptr(j), sample_size);
        }

        predict(temp_test_samples, temp_test_responses, 0);
        for( i = 0; i < test_sample_count; i++ )
        {
            float val = temp_test_responses.at<float>(i);
            j = sidx[(i+start+train_sample_count) % sample_count];
            if( is_classification )
                error += (float)(val != responses.at<int>(j));
            else
            {
                val -= responses.at<float>(j);
                error += val*val;
            }
        }
    }
    if( min_error > error )
    {
        min_error = error;
        best_params = params;
    }
}

params = best_params;
return do_train( samples, responses );
```

Best regards,
chris
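PS: For comparison, this is roughly the fold split I would have expected, with the test rows copied into temp_test_samples before predict is called. It is only my own minimal, self-contained sketch of the presumed intent (the index arithmetic mirrors the snippet above), not the actual OpenCV code:

```cpp
// Sketch of the expected k-fold split: test rows go into the test matrix.
// (My own illustration of the presumed intent, not the OpenCV trainAuto code.)
#include <opencv2/core.hpp>
#include <cstring>
#include <vector>

int main()
{
    const int sample_count = 10, var_count = 3, k_fold = 5;
    const int test_sample_count  = (sample_count + k_fold/2) / k_fold;
    const int train_sample_count = sample_count - test_sample_count;

    cv::Mat samples(sample_count, var_count, CV_32F);
    cv::randu(samples, 0.f, 1.f);

    // Shuffled sample indices, playing the role of sidx in trainAuto.
    std::vector<int> sidx(sample_count);
    for (int i = 0; i < sample_count; i++) sidx[i] = i;

    cv::Mat temp_train_samples(train_sample_count, var_count, CV_32F);
    cv::Mat temp_test_samples (test_sample_count,  var_count, CV_32F);
    const size_t sample_size = var_count * samples.elemSize();

    for (int k = 0; k < k_fold; k++)
    {
        int start = (k*sample_count + k_fold/2) / k_fold;

        // Training part of the fold.
        for (int i = 0; i < train_sample_count; i++)
        {
            int j = sidx[(i + start) % sample_count];
            std::memcpy(temp_train_samples.ptr(i), samples.ptr(j), sample_size);
        }

        // Test part of the fold: copied into the *test* matrix,
        // which is what I would have expected before predict() is called.
        for (int i = 0; i < test_sample_count; i++)
        {
            int j = sidx[(i + start + train_sample_count) % sample_count];
            std::memcpy(temp_test_samples.ptr(i), samples.ptr(j), sample_size);
        }
    }
    return 0;
}
```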
