sklearn KMeans 聚类实战,对沪深300的每日涨跌进行聚类

栏目: 数据库 · 发布时间: 7年前

# ohlc_clustering.py

import copy
import datetime
import pymysql

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# from matplotlib.finance import candlestick_ohlc
import matplotlib.dates as mdates
from matplotlib.dates import (
    DateFormatter, WeekdayLocator, DayLocator, MONDAY
)
import mpl_finance as mpf
import numpy as np
import pandas as pd
import pandas_datareader.data as web
from sklearn.cluster import KMeans

def get_open_normalised_prices():
    """
    Load the daily bars for code 399300 (CSI 300, per the article
    title) from the local MySQL ``tmp_stock`` table and return their
    open-normalised prices.

    Returns
    -------
    pandas.DataFrame
        One row per stored bar dated 2004-06-01 or later, in ascending
        date order, holding exactly three columns: ``H/O`` (High/Open),
        ``L/O`` (Low/Open) and ``C/O`` (Close/Open). ``Close`` is the
        table's ``adj_close`` column.
    """
    connect = pymysql.connect(
        host='127.0.0.1',
        db='blog',
        user='root',
        passwd='123456',
        charset='utf8',
        use_unicode=True
    )
    select_sql_300 = "select date as Date,open as Open,high as High,low as Low,adj_close as Close from `tmp_stock` where code ='399300' and date >= '2004-6-01'  order by date asc"
    # Release the connection even when the query raises
    try:
        df = pd.read_sql(select_sql_300, con=connect)
    finally:
        connect.close()

    # Express high, low and close as ratios of the same day's open
    df["H/O"] = df["High"]/df["Open"]
    df["L/O"] = df["Low"]/df["Open"]
    df["C/O"] = df["Close"]/df["Open"]

    # Only the three ratio columns are fed to the clustering step
    df.drop(
        [
            "Open", "High", "Low",
            "Close", "Date"
        ],
        axis=1, inplace=True
    )
    return df

def plot_candlesticks(data):
    """
    Plot a candlestick chart of the prices on a date x-axis.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``Date``, ``Open``, ``High``, ``Low`` and
        ``Close`` columns. The frame is not modified.
    """
    # Work on a copy so the caller's frame is left untouched
    df = data.copy()
    df.reset_index(inplace=True)
    # pd.to_datetime makes the conversion work for plain
    # datetime.date objects and date strings as well as Timestamps
    df['date_fmt'] = pd.to_datetime(df['Date']).apply(
        lambda date: mdates.date2num(date.to_pydatetime())
    )

    fig, ax = plt.subplots(figsize=(16, 4))
    fig.subplots_adjust(bottom=0.2)

    # Plot the candlestick OHLC chart using red for
    # up days and green for down days
    mpf.candlestick_ohlc(
        ax, df[
            ['date_fmt', 'Open', 'High', 'Low', 'Close']
        ].values, width=0.6,
        colorup='r', colordown='green'
    )
    # Interpret the numeric x values as dates when drawing tick labels
    ax.xaxis_date()
    plt.show()


def plot_cluster(data):
    """
    Scatter-plot the closing price over time, coloured by cluster.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``Date``, ``Close`` and ``Cluster`` columns. Only
        cluster ids 0-3 are drawn. The frame is not modified.
    """
    # Work on a copy so the caller's frame is left untouched
    df = data.copy()
    df.reset_index(inplace=True)
    # pd.to_datetime makes the conversion work for plain
    # datetime.date objects and date strings as well as Timestamps
    df['date_fmt'] = pd.to_datetime(df['Date']).apply(
        lambda date: mdates.date2num(date.to_pydatetime())
    )

    fig, ax = plt.subplots(figsize=(16, 4))
    fig.subplots_adjust(bottom=0.2)

    # (cluster id, colour, legend label), drawn in this order.
    # NOTE(review): the labels are hard-coded per cluster id; confirm
    # they match the centroids of the model that produced "Cluster".
    cluster_styles = (
        (0, 'y', "Small Rise"),
        (1, 'g', "Big Down"),
        (2, 'r', "Big Rise"),
        (3, 'b', "Small Down"),
    )

    size = 1.2
    for cluster_id, colour, label in cluster_styles:
        members = df.loc[df["Cluster"] == cluster_id]
        ax.scatter(
            members['date_fmt'], members['Close'],
            s=size, c=colour, marker='o', label=label
        )

    # Interpret the numeric x values as dates when drawing tick labels
    ax.xaxis_date()
    plt.xlabel('Date')
    plt.ylabel('Close')
    plt.legend(loc='upper right')
    plt.show()

def plot_3d_normalised_candles(data):
    """
    Plot a 3D scatterchart of the open-normalised bars
    highlighting the separate clusters by colour.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``H/O``, ``L/O`` and ``C/O`` columns. When a
        ``Cluster`` column is present it supplies the point colours;
        otherwise the module-level ``labels`` array is used.
    """
    fig = plt.figure(figsize=(12, 9))
    # Create the 3D axes through the figure so they are registered
    # with it; a directly constructed Axes3D(fig) is no longer added
    # to the figure automatically by current matplotlib releases.
    ax = fig.add_axes([0.0, 0.0, 1.0, 1.0], projection='3d')
    ax.view_init(elev=21, azim=-136)

    # Prefer the labels carried by the frame itself so the colours
    # always line up with the rows being plotted
    if "Cluster" in data.columns:
        cluster_ids = data["Cluster"]
    else:
        cluster_ids = labels

    ax.scatter(
        data["H/O"], data["L/O"], data["C/O"],
        # the builtin float replaces the removed np.float alias
        c=np.asarray(cluster_ids).astype(float)
    )
    ax.set_xlabel('High/Open')
    ax.set_ylabel('Low/Open')
    ax.set_zlabel('Close/Open')
    plt.show()

def plot_cluster_ordered_candles(data):
    """
    Draw the OHLC candles grouped by cluster membership instead of
    by date, marking every cluster boundary with a dashed blue
    vertical line.

    ``data`` must provide ``Open``, ``High``, ``Low``, ``Close`` and
    ``Cluster`` columns; it is not modified.
    """
    fig, ax = plt.subplots(figsize=(16,4))

    # Monday major ticks, daily minor ticks, empty major tick labels
    ax.xaxis.set_major_locator(WeekdayLocator(MONDAY))
    ax.xaxis.set_minor_locator(DayLocator())
    ax.xaxis.set_major_formatter(DateFormatter(""))

    # Order the rows by cluster; after the index reset the row
    # position doubles as the x coordinate of each candle
    ordered = data.sort_values(by="Cluster").reset_index()
    ordered["clust_index"] = ordered.index
    ordered["clust_change"] = ordered["Cluster"].diff()
    # Rows whose cluster differs from the previous row's (the first
    # row, whose diff is NaN, is selected as well)
    boundaries = ordered[ordered["clust_change"] != 0]

    # Black candles for up bars, red candles for down bars
    ohlc_columns = ["clust_index", 'Open', 'High', 'Low', 'Close']
    mpf.candlestick_ohlc(
        ax,
        ordered[ohlc_columns].values,
        width=0.6,
        colorup='#000000',
        colordown='#ff0000'
    )

    # One dashed blue line at the start of every cluster run
    for _, boundary in boundaries.iterrows():
        plt.axvline(
            boundary["clust_index"],
            linestyle="dashed", c="blue"
        )

    plt.xlim(0, len(ordered))
    plt.setp(
        plt.gca().get_xticklabels(),
        rotation=45, horizontalalignment='right'
    )
    plt.show()

def create_follow_cluster_matrix(data, n_clusters=None):
    """
    Build and print the k x k cluster follow-on matrix.

    Entry ``[i, j]`` is the percentage of consecutive row pairs in
    which a row of cluster ``i`` is immediately followed by a row of
    cluster ``j``. The last row has no successor and forms no pair.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain an integer ``Cluster`` column, in chronological
        row order. The frame is not modified.
    n_clusters : int, optional
        Size ``k`` of the matrix. Defaults to the largest cluster id
        found in ``data`` plus one (0 for an empty frame).

    Returns
    -------
    numpy.ndarray
        The ``(k, k)`` matrix of percentages that is also printed.
    """
    today = data["Cluster"]
    # Pair every row with the next row's cluster on a scratch frame,
    # so the caller's data gains no columns and loses no rows
    pairs = pd.DataFrame(
        {"today": today, "tomorrow": today.shift(-1)}
    ).dropna().astype(int)

    if n_clusters is None:
        n_clusters = int(today.max()) + 1 if len(today) else 0

    # Count every (today, tomorrow) transition; ufunc.at accumulates
    # correctly when the same cell is hit more than once
    clust_mat = np.zeros((n_clusters, n_clusters))
    np.add.at(
        clust_mat,
        (pairs["today"].to_numpy(), pairs["tomorrow"].to_numpy()),
        1.0
    )
    # Convert counts to percentages; skipped when there are no pairs
    # to avoid dividing by zero
    if len(pairs):
        clust_mat = clust_mat * 100.0 / len(pairs)

    print("Cluster Follow-on Matrix:")
    print(clust_mat)
    return clust_mat


if __name__ == "__main__":
    # Load the daily OHLC bars for code 399300 (CSI 300, per the
    # article title) from the local MySQL database
    connect = pymysql.connect(
        host='127.0.0.1',
        db='blog',
        user='root',
        passwd='123456',
        charset='utf8',
        use_unicode=True
    )
    select_sql_300 = "select date as Date,open as Open,high as High,low as Low,adj_close as Close from `tmp_stock` where code ='399300' and date >= '2004-6-01'  order by date asc"
    # Release the connection even when the query raises
    try:
        hs300 = pd.read_sql(select_sql_300, con=connect)
    finally:
        connect.close()

    # Plot the price "candles" for the whole loaded history
    plot_candlesticks(hs300)

    # Carry out K-Means clustering with four clusters on the
    # three-dimensional data H/O, L/O and C/O.
    # `k`, `labels` and `hs300` are module-level names; keep them,
    # since helper functions in this file may read them as globals.
    hs300_norm = get_open_normalised_prices()
    k = 4
    km = KMeans(n_clusters=k, random_state=42)
    km.fit(hs300_norm)
    labels = km.labels_
    # hs300_norm is built from the same query as hs300, so the label
    # array is attached to both frames row for row
    hs300_norm["Cluster"] = labels
    hs300["Cluster"] = labels

    # Plot the 3D normalised candles using H/O, L/O, C/O
    plot_3d_normalised_candles(hs300_norm)

    # Create and output the cluster follow-on matrix
    create_follow_cluster_matrix(hs300)

    # Plot the closing price over time, coloured by cluster
    plot_cluster(hs300)

sklearn kMeans 分类实战,对沪深300的每日涨跌进行分类

sklearn kMeans 分类实战,对沪深300的每日涨跌进行分类

sklearn kMeans 分类实战,对沪深300的每日涨跌进行分类

http://www.waitingfy.com/archives/5039

参考:

https://zhuanlan.zhihu.com/p/43872533

https://www.quantstart.com/articles/k-means-clustering-of-daily-ohlc-bar-data

Post Views: 0

5039


以上就是本文的全部内容,希望本文的内容对大家的学习或者工作能带来一定的帮助,也希望大家多多支持 码农网。

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

Introduction to Computation and Programming Using Python

Introduction to Computation and Programming Using Python

John V. Guttag / The MIT Press / 2013-7 / USD 25.00

This book introduces students with little or no prior programming experience to the art of computational problem solving using Python and various Python libraries, including PyLab. It provides student......一起来看看 《Introduction to Computation and Programming Using Python》 这本书的介绍吧!

Markdown 在线编辑器
Markdown 在线编辑器

Markdown 在线编辑器

RGB CMYK 转换工具
RGB CMYK 转换工具

RGB CMYK 互转工具

HSV CMYK 转换工具
HSV CMYK 转换工具

HSV CMYK互换工具