sklearn KMeans 聚类实战,对沪深300的每日涨跌进行聚类

# ohlc_clustering.py

import copy
import datetime
import pymysql

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# from matplotlib.finance import candlestick_ohlc
import matplotlib.dates as mdates
from matplotlib.dates import (
    DateFormatter, WeekdayLocator, DayLocator, MONDAY
)
import mpl_finance as mpf
import numpy as np
import pandas as pd
import pandas_datareader.data as web
from sklearn.cluster import KMeans

def get_open_normalised_prices():
    """
    Load the daily bars for code 399300 from the local MySQL
    ``tmp_stock`` table and return their open-normalised prices.

    Returns:
        pandas.DataFrame holding only the columns "H/O", "L/O" and
        "C/O" (High/Open, Low/Open and Close/Open), one row per bar
        in ascending date order. "Close" is taken from the table's
        ``adj_close`` column.
    """
    select_sql_300 = "select date as Date,open as Open,high as High,low as Low,adj_close as Close from `tmp_stock` where code ='399300' and date >= '2004-6-01'  order by date asc"

    connect = pymysql.connect(
        host='127.0.0.1',
        db='blog',
        user='root',
        passwd='123456',
        charset='utf8',
        use_unicode=True
    )
    try:
        df = pd.read_sql(select_sql_300, con=connect)
    finally:
        # Release the connection even when the query raises
        connect.close()

    df["H/O"] = df["High"]/df["Open"]
    df["L/O"] = df["Low"]/df["Open"]
    df["C/O"] = df["Close"]/df["Open"]
    # Keep only the three ratio columns; they are the clustering features
    df.drop(
        [
            "Open", "High", "Low",
            "Close", "Date"
        ],
        axis=1, inplace=True
    )
    return df

def plot_candlesticks(data):
    """
    Draw a candlestick chart of the bars in ``data`` on a date
    x-axis and show it.

    Args:
        data: DataFrame providing "Date", "Open", "High", "Low" and
            "Close". The frame itself is not modified; plotting is
            done on a copy.
    """
    # Work on a copy; reset_index turns the current index into a
    # regular column so that it does not take part in the plot
    df = copy.deepcopy(data)
    df.reset_index(inplace=True)
    # pd.Timestamp accepts datetime.date, datetime.datetime and
    # Timestamp values alike, so the conversion does not depend on
    # how the database driver returned the "Date" column
    df['date_fmt'] = df['Date'].apply(
        lambda date: mdates.date2num(pd.Timestamp(date).to_pydatetime())
    )

    fig, ax = plt.subplots(figsize=(16, 4))
    fig.subplots_adjust(bottom=0.2)

    # Plot the candlestick OHLC chart using red for
    # up days and green for down days
    mpf.candlestick_ohlc(
        ax, df[
            ['date_fmt', 'Open', 'High', 'Low', 'Close']
        ].values, width=0.6,
        colorup='r', colordown='green'
    )
    # Interpret the numeric x values as dates when labelling ticks
    ax.xaxis_date()
    plt.show()


def plot_cluster(data):
    """
    Scatter-plot the close price over time, coloured by cluster,
    and show the figure.

    Args:
        data: DataFrame providing "Date", "Close" and "Cluster".
            Only cluster ids 0-3 are drawn. The frame itself is not
            modified; plotting is done on a copy.
    """
    df = copy.deepcopy(data)
    df.reset_index(inplace=True)
    # pd.Timestamp accepts datetime.date, datetime.datetime and
    # Timestamp values alike, so the conversion does not depend on
    # how the database driver returned the "Date" column
    df['date_fmt'] = df['Date'].apply(
        lambda date: mdates.date2num(pd.Timestamp(date).to_pydatetime())
    )

    fig, ax = plt.subplots(figsize=(16, 4))
    fig.subplots_adjust(bottom=0.2)

    # (cluster id, marker colour, legend label), drawn in this order.
    # NOTE(review): the label attached to each id is hard-coded;
    # confirm it still matches the fitted clusters after any change
    # to the data or to the clustering parameters.
    cluster_styles = (
        (0, 'y', "Small Rise"),
        (1, 'g', "Big Down"),
        (2, 'r', "Big Rise"),
        (3, 'b', "Small Down"),
    )
    size = 1.2
    for cluster_id, colour, label in cluster_styles:
        members = df.loc[df["Cluster"] == cluster_id]
        ax.scatter(
            members['date_fmt'], members['Close'],
            s=size, c=colour, marker='o', label=label
        )

    # Interpret the numeric x values as dates when labelling ticks
    ax.xaxis_date()
    plt.xlabel('Date')
    plt.ylabel('Close')
    plt.legend(loc='upper right')
    plt.show()

def plot_3d_normalised_candles(data):
    """
    Plot a 3D scatterchart of the open-normalised bars
    highlighting the separate clusters by colour

    Args:
        data: DataFrame providing "H/O", "L/O", "C/O" and the
            per-row cluster label in "Cluster".
    """
    fig = plt.figure(figsize=(12, 9))
    # Create the axes through the figure so that they are attached
    # to it; an Axes3D built directly from the figure is not added
    # automatically by recent Matplotlib releases.
    ax = fig.add_subplot(111, projection='3d')
    ax.view_init(elev=21, azim=-136)
    ax.scatter(
        data["H/O"], data["L/O"], data["C/O"],
        # Colour by the labels stored with the data rather than by a
        # module-level variable; the builtin float replaces the
        # removed np.float alias
        c=data["Cluster"].astype(float)
    )
    ax.set_xlabel('High/Open')
    ax.set_ylabel('Low/Open')
    ax.set_zlabel('Close/Open')
    plt.show()

def plot_cluster_ordered_candles(data):
    """
    Show a candlestick chart whose bars are ordered by cluster
    membership, with a dashed blue vertical line wherever the
    cluster label changes.

    Args:
        data: DataFrame providing "Open", "High", "Low", "Close"
            and "Cluster". It is copied before being sorted.
    """
    # Figure with Monday major ticks, daily minor ticks and
    # empty major tick labels
    figure, axes = plt.subplots(figsize=(16,4))
    axes.xaxis.set_major_locator(WeekdayLocator(MONDAY))
    axes.xaxis.set_minor_locator(DayLocator())
    axes.xaxis.set_major_formatter(DateFormatter(""))

    # Order the bars by cluster label and number them 0..n-1;
    # that running number is the x position of each candle
    ordered = copy.deepcopy(data)
    ordered.sort_values(by="Cluster", inplace=True)
    ordered.reset_index(inplace=True)
    ordered["clust_index"] = ordered.index
    # A non-zero difference to the previous row marks the start of
    # a new cluster (the first row has no previous row, so it is
    # selected as well)
    ordered["clust_change"] = ordered["Cluster"].diff()
    boundaries = ordered[ordered["clust_change"] != 0]

    # Candles: black when rising, red when falling
    ohlc_columns = ["clust_index", 'Open', 'High', 'Low', 'Close']
    mpf.candlestick_ohlc(
        axes, ordered[ohlc_columns].values,
        width=0.6, colorup='#000000', colordown='#ff0000'
    )

    # One dashed blue line per cluster boundary
    for _, boundary in boundaries.iterrows():
        plt.axvline(
            boundary["clust_index"],
            linestyle="dashed", c="blue"
        )

    plt.xlim(0, len(ordered))
    tick_labels = plt.gca().get_xticklabels()
    plt.setp(tick_labels, rotation=45, horizontalalignment='right')
    plt.show()

def create_follow_cluster_matrix(data, n_clusters=None):
    """
    Creates a k x k matrix, where k is the number of clusters
    that shows when cluster j follows cluster i.

    Entry [i, j] is the percentage of rows (after the last row has
    been dropped) whose cluster is i and whose next row's cluster
    is j. The matrix is printed and returned.

    Args:
        data: DataFrame with an integer "Cluster" column in time
            order. It is modified in place: the columns
            "ClusterTomorrow" and "ClusterMatrix" are added and rows
            without a following row are removed.
        n_clusters: Size k of the matrix. When omitted it is taken
            as the largest cluster id present plus one.

    Returns:
        numpy.ndarray of shape (k, k) holding the percentages.
    """
    data["ClusterTomorrow"] = data["Cluster"].shift(-1)
    # Only the missing follow-on label matters here; rows that have
    # NaN in some unrelated column are kept
    data.dropna(subset=["ClusterTomorrow"], inplace=True)
    data["ClusterTomorrow"] = data["ClusterTomorrow"].astype(int)
    data["ClusterMatrix"] = list(zip(data["Cluster"], data["ClusterTomorrow"]))

    if n_clusters is None:
        if data.empty:
            n_clusters = 0
        else:
            largest_id = max(
                data["Cluster"].max(), data["ClusterTomorrow"].max()
            )
            n_clusters = int(largest_id) + 1

    clust_mat = np.zeros((n_clusters, n_clusters))
    total = len(data)
    # Each key is an (i, j) pair, which indexes the matrix directly
    for pair, count in data["ClusterMatrix"].value_counts().items():
        clust_mat[pair] = count*100.0/total
    print("Cluster Follow-on Matrix:")
    print(clust_mat)
    return clust_mat


if __name__ == "__main__":
    # Load the daily bars for code 399300 from the local MySQL
    # `tmp_stock` table, oldest first
    select_sql_300 = "select date as Date,open as Open,high as High,low as Low,adj_close as Close from `tmp_stock` where code ='399300' and date >= '2004-6-01'  order by date asc"
    connect = pymysql.connect(
        host='127.0.0.1',
        db='blog',
        user='root',
        passwd='123456',
        charset='utf8',
        use_unicode=True
    )
    try:
        hs300 = pd.read_sql(select_sql_300, con=connect)
    finally:
        # Release the connection even when the query raises
        connect.close()

    # Plot the price "candles"
    plot_candlesticks(hs300)

    # Carry out K-Means clustering with four clusters on the
    # three-dimensional data H/O, L/O and C/O
    hs300_norm = get_open_normalised_prices()
    k = 4
    # n_init is spelled out so that the number of restarts does not
    # depend on the installed scikit-learn version's default
    km = KMeans(n_clusters=k, n_init=10, random_state=42)
    km.fit(hs300_norm)
    labels = km.labels_
    hs300_norm["Cluster"] = labels
    # hs300 and hs300_norm come from two separate runs of the same
    # query, so the labels are matched to the bars purely by row
    # position
    hs300["Cluster"] = labels

    # Plot the 3D normalised candles using H/O, L/O, C/O
    plot_3d_normalised_candles(hs300_norm)

    # Create and output the cluster follow-on matrix
    create_follow_cluster_matrix(hs300)

    plot_cluster(hs300)

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述
http://www.waitingfy.com/archives/5039
参考:
https://zhuanlan.zhihu.com/p/43872533
https://www.quantstart.com/articles/k-means-clustering-of-daily-ohlc-bar-data

猜你喜欢

转载自blog.csdn.net/fox64194167/article/details/82904709