RPA 量化交易系统 代码-Py语言版-立哥开发

Python 3.7.5 (tags/v3.7.5:5c02a39a0b, Oct 15 2019, 00:11:34) [MSC v.1916 64 bit (AMD64)] on win32
Type "help", "copyright", "credits" or "license()" for more information.
>>> # installed via pip
#Copy right  2020  Jacky  Zong . All  rihgts  reserved .
import numpy as np
import matplotlib.pyplot as plt
from sklearn.externals import joblib
from pyriemann.classification import MDM
from pyriemann.utils.distance import distance_riemann 
from tqdm import tqdm
from collections import OrderedDict 

# get the functions from RPA package
import rpa.transfer_learning as TL
import rpa.diffusion_map as DM
import rpa.get_dataset as GD

# get the dataset
datafolder = '../datasets/'
paradigm = 'motorimagery'
target = GD.get_dataset(datafolder, paradigm, subject=1) 
source = GD.get_dataset(datafolder, paradigm, subject=2)

# instantiate the Riemannian classifier to use
clf = MDM() 

# create a scores dictionary
methods_list = ['org', 'rct', 'rpa', 'clb']
scores = OrderedDict()
for method in methods_list:
    scores[method] = []

nrzt = 5
for _ in tqdm(range(nrzt)):

    # get the split for the source and target dataset
    source_org, target_org_train, target_org_test = TL.get_sourcetarget_split(source, target, ncovs_train=20)

    # get the score with the original dataset
    scores['org'].append(TL.get_score_transferlearning(clf, source_org, target_org_train, target_org_test))

    # get the score with the re-centered matrices
    source_rct, target_rct_train, target_rct_test = TL.RPA_recenter(source_org, target_org_train, target_org_test)
    scores['rct'].append(TL.get_score_transferlearning(clf, source_rct, target_rct_train, target_rct_test))

    # rotate the re-centered-stretched matrices using information from classes
    source_rpa, target_rpa_train, target_rpa_test = TL.RPA_rotate(source_rct, target_rct_train, target_rct_test)
    scores['rpa'].append(TL.get_score_transferlearning(clf, source_rpa, target_rpa_train, target_rpa_test))

    # get the score without any transformation
    scores['clb'].append(TL.get_score_notransfer(clf, target_org_train, target_org_test))
    
for method in methods_list:
    scores[method] = np.mean(scores[method])

for meth in ['org', 'rct', 'rpa', 'clb']:
    print(meth, '{0:.2f}'.format(scores[meth]))   Type "help", "copyright", "credits" or "license()" for more information.Type "help", "copyright", "credits" or "license()" for more information.

猜你喜欢

转载自blog.csdn.net/weixin_45806384/article/details/108007908