surprise库使用(二)——使用自定义数据集

surprise库有一组内建 数据集,但当然可以使用自定义数据集。加载评分数据集可以从文件(例如csv文件)或pandas数据框中完成。无论哪种方式,都需要定义一个Reader对象来解析文件或数据框。\

要从文件加载数据集(例如csv文件),需要 load_from_file()方法:

from surprise import BaselineOnly
from surprise import Dataset
from surprise import Reader
from surprise.model_selection import cross_validate

file_path = os.path.expanduser('~/.surprise_data/ml-100k/ml-100k/u.data')#数据集文件所在目录

reader = Reader(line_format='user item rating timestamp', sep='\t')

data = Dataset.load_from_file(file_path, reader=reader)

cross_validate(BaselineOnly(), data, verbose=True) #现在可以使用这个数据集,例如调用cross_validate 

从pandas数据框加载数据集,需要使用 load_from_df()方法。这里不多赘述,可自己查看说明


使用交叉验证迭代器

对于交叉验证,可以用 cross_validate() 完成所有的工作。但是为了更好地控制,可以实例化交叉验证迭代器,并使用且迭代器的 split() 方法和算法的 test() 方法,对每一折进行预测。

下面是一个栗子,我们使用了一个经典的K-fold交叉验证程序,其中包含数据被分为3份(3折交叉验证):

from surprise import SVD
from surprise import Dataset
from surprise import accuracy
from surprise.model_selection import KFold

data = Dataset.load_builtin('ml-100k') #加载数据集

# define a cross-validation iterator
kf = KFold(n_splits=3) #定义交叉验证迭代器

algo = SVD()

for trainset, testset in kf.split(data):

    # 训练并测试算法
    algo.fit(trainset)
    predictions = algo.test(testset)

    # 计算并打印RMSE
    accuracy.rmse(predictions, verbose=True)

结果:

RMSE: 0.9374
RMSE: 0.9476
RMSE: 0.9478

也可以使用其他交叉验证迭代器,例如LeaveOneOut或ShuffleSplit。在这里查看所有可用的迭代器Surprise的交叉验证工具的设计灵感来源于优秀的scikit-learn API。


交叉验证的一个特例是folds已经由某些文件预定义,这里同样查看说明。


使用GridSearchCV调整算法参数

cross_validate()函数针对给定的一组交叉验证参数报告过程的准确性度量(如RMSE、MAE这些)。如果你想知道哪个参数组合能够产生最好的结果,那么这个 GridSearchCV类就可以解决问题。给定一个dict参数,这个类会尝试所有的参数组合,并报告任何准确性度量(对不同分割进行平均的)的最佳参数。它受到scikit-learn的GridSearchCV的启发

接下来这个例子我们尝试了SVD算法的参数 n_epochslr_all 和 reg_all 的不同值。

from surprise import SVD
from surprise import Dataset
from surprise.model_selection import GridSearchCV

data = Dataset.load_builtin('ml-100k')

param_grid = {'n_epochs': [5, 10], 'lr_all': [0.002, 0.005],
              'reg_all': [0.4, 0.6]}
gs = GridSearchCV(SVD, param_grid, measures=['rmse', 'mae'], cv=3)

gs.fit(data)

# best RMSE score
print(gs.best_score['rmse'])

# combination of parameters that gave the best RMSE score
print(gs.best_params['rmse'])

结果:

0.961300130118
{'n_epochs': 10, 'lr_all': 0.005, 'reg_all': 0.4}

我们在这里评估3倍交叉验证过程的平均RMSE和MAE,但可以使用任何交叉验证迭代器

一旦fit()被调用,  best_estimator 这个属性给了我们一个算法实例最优的一组参数,可以根据我们的喜好使用它:

# 可以使用产生最优RMSE的算法
algo = gs.best_estimator['rmse']
algo.fit(data.build_full_trainset())

注意:字典参数,例如bsl_optionssim_options需要特殊对待。请参阅以下使用示例:

param_grid = {'k': [10, 20],
              'sim_options': {'name': ['msd', 'cosine'],
                              'min_support': [1, 5],
                              'user_based': [False]}
              }

当然,两者可以结合使用,例如 KNNBaseline :

param_grid = {'bsl_options': {'method': ['als', 'sgd'],
                              'reg': [1, 2]},
              'k': [2, 3],
              'sim_options': {'name': ['msd', 'cosine'],
                              'min_support': [1, 5],
                              'user_based': [False]}
              }

为了进一步分析,cv_results属性具有所有需要的信息,并且可以在pandas数据框中导入:

results_df = pd.DataFrame.from_dict(gs.cv_results)

在我们的例子中,该cv_results属性看起来像这样(float格式):

'split0_test_rmse'  [ 1.0  1.0  0.97  0.98  0.98  0.99  0.96  0.97 ] 
'split1_test_rmse'  [ 1.0  1.0  0.97  0.98  0.98  0.99  0.96  0.97 ] 
'split2_test_rmse'  [ 1.0  1.0  0.97  0.98  0.98  0.99  0.96  0.97 ] 
'mean_test_rmse'    [ 1.0  1.0  0.97  0.98  0.98  0.99  0.96  0.97 ] 
'std_test_rmse'     [ 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0 ] 
'rank_test_rmse'    [ 7  8  3  5  4  6  1  2 ]
'split0_test_mae'   [ 0.81  0.82  0.78  0.79  0.79  0.8  0.77  0.79 ] 
'split1_test_mae'   [ 0.8  0.81  0.78  0.79  0.78  0.79  0.77  0.78 ] 
'split2_test_mae'   [ 0.81  0.81  0.78  0.79  0.78  0.8  0.77  0.78 ] 
'mean_test_mae'     [ 0.81  0.81  0.78  0.79  0.79  0.8  0.77  0.78 ] 
'std_test_mae'      [ 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0 ] 
'rank_test_mae'     [ 7  8  2  5  4  6  1  3 ]
'mean_fit_time'     [ 1.53  1.52  1.53  1.53  3.04  3.05  3.06  3.02 ] 
'std_fit_time'      [ 0.03  0.04  0.0  0.01  0.04  0.01  0.06  0.01 ] 
'mean_test_time'    [ 0.46  0.45  0.44  0.44  0.47  0.49  0.46  0.34 ] 
'std_test_time'     [ 0.0  0.01  0.01  0.0  0.03  0.06  0.01  0.08 ] 
'PARAMS'            [{ 'n_epochs'  5  'lr_all'  0.002  'reg_all'  0.4 }, { 'n_epochs'  5  'lr_all'  0.002  'reg_all'  0.6 }, {'n_epochs'  5  'lr_all'  0.005  'reg_all'  0.4 }, { 'n_epochs'  5  'lr_all'  0.005  'reg_all'  0.6 }, { 'n_epochs'  10  'lr_all'  0.002  'reg_all'  0.4 }, { 'n_epochs'  10  'lr_all'  0.002  'reg_all'  0。6 }, {'n_epochs'  10  'lr_all'  0.005  'reg_all'  0.4 }, { 'n_epochs'  10  'lr_all'  0.4 0.6 0.4 0.6 ]0.005, 'reg_all': 0.6}]
'param_n_epochs':   [5, 5, 5, 5, 10, 10, 10, 10]
'param_lr_all':     [0.0, 0.0, 0.01, 0.01, 0.0, 0.0, 0.01, 0.01]
'param_reg_all':    [0.4, 0.6, 0.4, 0.6,    



猜你喜欢

转载自blog.csdn.net/yuxeaotao/article/details/79852254