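Python ZigZag algorithm function not returning the expected result

Overview

I am trying to use this Python zigzag candlestick indicator (which uses the high, low, and close values) on financial data, but the code below seems to have a bug.

Is there another Python module available that provides this functionality?

What is the Zig Zag indicator?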

The Zig Zag indicator plots points on the chart whenever the price reverses by a percentage greater than a pre-selected threshold.
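
As a rough illustration of that definition, here is a minimal close-only sketch. It is not the pull request's algorithm: the function name zigzag_close and its simplifications (for example, deciding the first leg from the first price) are assumptions made for illustration.

```python
def zigzag_close(prices, pct):
    """Return indices of confirmed zigzag pivots in a single price series.

    A pivot is confirmed once price retraces from the running extreme of the
    current leg by at least pct (e.g. 0.05 for 5%). Index 0 is always included.
    """
    pivots = [0]
    trend = 0   # 0 = undecided, 1 = rising leg, -1 = falling leg
    ext_i = 0   # index of the running extreme (highest / lowest price) of the leg
    for i in range(1, len(prices)):
        p, ext = prices[i], prices[ext_i]
        if trend == 0:
            # Decide the first leg once price has moved pct away from the start.
            if p >= prices[0] * (1 + pct):
                trend, ext_i = 1, i
            elif p <= prices[0] * (1 - pct):
                trend, ext_i = -1, i
        elif trend == 1:
            if p > ext:
                ext_i = i                   # new high extends the rising leg
            elif p <= ext * (1 - pct):
                pivots.append(ext_i)        # retrace confirms ext_i as a peak
                trend, ext_i = -1, i
        else:
            if p < ext:
                ext_i = i                   # new low extends the falling leg
            elif p >= ext * (1 + pct):
                pivots.append(ext_i)        # rebound confirms ext_i as a valley
                trend, ext_i = 1, i
    return pivots


print(zigzag_close([100, 103, 106, 104, 101, 99, 102, 105], 0.05))  # [0, 2, 5]
```

Index 2 (106) is confirmed as a peak when the price falls to 99, more than 5% below it, and index 5 (99) is confirmed as a valley when the price rebounds to 105.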

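Source

What I have tried

While searching for a Python zigzag indicator for candlestick charts, the only code I could find came from this pull request.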

import numpy as np

# _identify_initial_pivot is defined in the full code in the answer below.
def peak_valley_pivots_candlestick(close, high, low, up_thresh, down_thresh):
    """
    Finds the peaks and valleys of a series of HLC (open is not necessary).
    TR: This is modified peak_valley_pivots function in order to find peaks and valleys for OHLC.
    Parameters
    ----------
    close : Series of close prices.
    high : Series of high prices.
    low : Series of low prices.
    up_thresh : The minimum relative change necessary to define a peak.
    down_thresh : The minimum relative change necessary to define a valley (must be negative).
    Returns
    -------
    an array with 0 indicating no pivot and -1 and 1 indicating valley and peak
    respectively
    Using Pandas
    ------------
    For the most part, close, high and low may be a pandas series. However, the index must
    either be [0,n) or a DateTimeIndex. Why? This function does X[t] to access
    each element where t is in [0,n).
    The First and Last Elements
    ---------------------------
    The first and last elements are guaranteed to be annotated as peak or
    valley even if the segments formed do not have the necessary relative
    changes. This is a tradeoff between technical correctness and the
    propensity to make mistakes in data analysis. The possible mistake is
    ignoring data outside the fully realized segments, which may bias analysis.
    """
    if down_thresh > 0:
        raise ValueError('The down_thresh must be negative.')

    initial_pivot = _identify_initial_pivot(close, up_thresh, down_thresh)

    t_n = len(close)
    pivots = np.zeros(t_n, dtype='i1')
    pivots[0] = initial_pivot

    # Adding one to the relative change thresholds saves operations. Instead
    # of computing relative change at each point as x_j / x_i - 1, it is
    # computed as x_j / x_i. Then, this value is compared to the threshold + 1.
    # This saves (t_n - 1) subtractions.
    up_thresh += 1
    down_thresh += 1

    trend = -initial_pivot
    last_pivot_t = 0
    last_pivot_x = close[0]
    for t in range(1, len(close)):

        if trend == -1:
            x = low[t]
            r = x / last_pivot_x
            if r >= up_thresh:
                pivots[last_pivot_t] = trend
                trend = 1
                last_pivot_x = x
                last_pivot_t = t
            elif x < last_pivot_x:
                last_pivot_x = x
                last_pivot_t = t
        else:
            x = high[t]
            r = x / last_pivot_x
            if r <= down_thresh:
                pivots[last_pivot_t] = trend
                trend = -1
                last_pivot_x = x
                last_pivot_t = t
            elif x > last_pivot_x:
                last_pivot_x = x
                last_pivot_t = t


    if last_pivot_t == t_n-1:
        pivots[last_pivot_t] = trend
    elif pivots[t_n-1] == 0:
        pivots[t_n-1] = trend

    return pivots 
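
The comment about adding one to the thresholds relies on this equivalence, where u = up_thresh, d = down_thresh (negative), and x_i is the last pivot price:

$$\frac{x_j}{x_i} - 1 \ge u \iff \frac{x_j}{x_i} \ge 1 + u, \qquad \frac{x_j}{x_i} - 1 \le d \iff \frac{x_j}{x_i} \le 1 + d$$

With up_thresh = .01 and down_thresh = -.01, a bar therefore counts as a reversal when the ratio is at least 1.01 or at most 0.99.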

It can be used as follows:

pivots = peak_valley_pivots_candlestick(df.Close, df.High, df.Low ,.01,-.01)
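
The docstring requires the index to be [0, n) (or a DatetimeIndex) because the function reads close[t] for t in range(n). The dict1 slice below has an integer index starting at 77, and with an integer index pandas treats close[0] as a label lookup, so it raises KeyError. A minimal sketch of two workarounds, assuming pandas is installed: reset the index, or pass plain NumPy arrays.

```python
import pandas as pd

# A small series whose index starts at 77, like the slice in dict1.
close = pd.Series([1.0, 2.0, 3.0], index=[77, 78, 79])

# With an integer index, close[0] is a label lookup, so it fails.
try:
    close[0]
    lookup_failed = False
except KeyError:
    lookup_failed = True

# Option 1: renumber the rows 0..n-1 before calling the function.
close_reset = close.reset_index(drop=True)

# Option 2: pass NumPy arrays, which are always indexed by position, e.g.
# pivots = peak_valley_pivots_candlestick(df.Close.to_numpy(), df.High.to_numpy(),
#                                         df.Low.to_numpy(), .01, -.01)
close_arr = close.to_numpy()

print(lookup_failed, close_reset[0], close_arr[0])  # True 1.0 1.0
```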

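The peak_valley_pivots_candlestick function works almost as expected, but for the following data there seems to be an error in how the pivots are calculated.

Data

The data below is a slice of the full dataset.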

from numpy import nan  # the 'Pivot Price' values below use bare nan

dict1 = {'Date': {77: '2018-12-19',
  78: '2018-12-20',
  79: '2018-12-21',
  80: '2018-12-24',
  81: '2018-12-25',
  82: '2018-12-26',
  83: '2018-12-27',
  84: '2018-12-28',
  85: '2018-12-31',
  86: '2019-01-01',
  87: '2019-01-02',
  88: '2019-01-03',
  89: '2019-01-04',
  90: '2019-01-07',
  91: '2019-01-08',
  92: '2019-01-09',
  93: '2019-01-10',
  94: '2019-01-11',
  95: '2019-01-14',
  96: '2019-01-15',
  97: '2019-01-16',
  98: '2019-01-17',
  99: '2019-01-18',
  100: '2019-01-21',
  101: '2019-01-22',
  102: '2019-01-23',
  103: '2019-01-24',
  104: '2019-01-25',
  105: '2019-01-28',
  106: '2019-01-29',
  107: '2019-01-30',
  108: '2019-01-31',
  109: '2019-02-01',
  110: '2019-02-04',
  111: '2019-02-05'},
 'Open': {77: 1.2654544115066528,
  78: 1.2625147104263306,
  79: 1.266993522644043,
  80: 1.2650061845779421,
  81: 1.2712942361831665,
  82: 1.2689388990402222,
  83: 1.2648460865020752,
  84: 1.264606237411499,
  85: 1.2689228057861328,
  86: 1.275022268295288,
  87: 1.2752337455749512,
  88: 1.2518777847290041,
  89: 1.2628973722457886,
  90: 1.2732852697372437,
  91: 1.2786905765533447,
  92: 1.2738852500915527,
  93: 1.2799508571624756,
  94: 1.275835633277893,
  95: 1.2849836349487305,
  96: 1.2876144647598269,
  97: 1.287282943725586,
  98: 1.2884771823883057,
  99: 1.298296570777893,
  100: 1.2853471040725708,
  101: 1.2892745733261108,
  102: 1.2956725358963013,
  103: 1.308318257331848,
  104: 1.3112174272537231,
  105: 1.3207770586013794,
  106: 1.3159972429275513,
  107: 1.308061599731445,
  108: 1.311681866645813,
  109: 1.3109252452850342,
  110: 1.3078563213348389,
  111: 1.3030844926834106},
 'High': {77: 1.267909288406372,
  78: 1.2705351114273071,
  79: 1.269728422164917,
  80: 1.273658275604248,
  81: 1.277791976928711,
  82: 1.2719732522964478,
  83: 1.2671220302581787,
  84: 1.2700024843215942,
  85: 1.2813942432403564,
  86: 1.2756729125976562,
  87: 1.2773349285125732,
  88: 1.2638230323791504,
  89: 1.2739664316177368,
  90: 1.2787723541259766,
  91: 1.2792304754257202,
  92: 1.2802950143814087,
  93: 1.2801146507263184,
  94: 1.2837464809417725,
  95: 1.292774677276611,
  96: 1.2916558980941772,
  97: 1.2895737886428833,
  98: 1.2939958572387695,
  99: 1.299376368522644,
  100: 1.2910722494125366,
  101: 1.296714186668396,
  102: 1.3080273866653442,
  103: 1.3095861673355105,
  104: 1.3176618814468384,
  105: 1.3210039138793943,
  106: 1.3196616172790527,
  107: 1.311991572380066,
  108: 1.3160665035247805,
  109: 1.311475396156311,
  110: 1.3098777532577517,
  111: 1.3051422834396362},
 'Low': {77: 1.2608431577682495,
  78: 1.2615113258361816,
  79: 1.2633600234985352,
  80: 1.2636953592300415,
  81: 1.266784906387329,
  82: 1.266512155532837,
  83: 1.261877417564392,
  84: 1.2636473178863523,
  85: 1.268182635307312,
  86: 1.2714558839797974,
  87: 1.2584631443023682,
  88: 1.2518777847290041,
  89: 1.261781930923462,
  90: 1.2724264860153198,
  91: 1.2714881896972656,
  92: 1.271779179573059,
  93: 1.273058295249939,
  94: 1.2716660499572754,
  95: 1.2821005582809448,
  96: 1.2756240367889404,
  97: 1.2827255725860596,
  98: 1.2836146354675293,
  99: 1.2892080545425415,
  100: 1.2831699848175049,
  101: 1.2855949401855469,
  102: 1.2945822477340698,
  103: 1.301371693611145,
  104: 1.3063528537750244,
  105: 1.313870549201965,
  106: 1.313145875930786,
  107: 1.3058068752288818,
  108: 1.3101180791854858,
  109: 1.3045804500579834,
  110: 1.3042230606079102,
  111: 1.2929919958114624},
 'Close': {77: 1.2655024528503418,
  78: 1.262785792350769,
  79: 1.2669775485992432,
  80: 1.2648941278457642,
  81: 1.2710840702056885,
  82: 1.2688745260238647,
  83: 1.2648781538009644,
  84: 1.2646220922470093,
  85: 1.269357681274414,
  86: 1.2738043069839478,
  87: 1.2754288911819458,
  88: 1.2521913051605225,
  89: 1.2628813982009888,
  90: 1.2734960317611694,
  91: 1.278608798980713,
  92: 1.2737879753112793,
  93: 1.279967188835144,
  94: 1.2753963470458984,
  95: 1.2849836349487305,
  96: 1.2874983549118042,
  97: 1.2872166633605957,
  98: 1.28857684135437,
  99: 1.2983977794647217,
  100: 1.2853471040725708,
  101: 1.2891747951507568,
  102: 1.295773148536682,
  103: 1.308215618133545,
  104: 1.3121638298034668,
  105: 1.3208470344543457,
  106: 1.3160146474838257,
  107: 1.30804443359375,
  108: 1.3117163181304932,
  109: 1.3109424114227295,
  110: 1.3077365159988403,
  111: 1.3031013011932373},
 'Pivots': {77: 0,
  78: 0,
  79: 0,
  80: 0,
  81: 0,
  82: 0,
  83: 0,
  84: 0,
  85: 1,
  86: 0,
  87: 0,
  88: 0,
  89: -1,
  90: 0,
  91: 0,
  92: 0,
  93: 0,
  94: 0,
  95: 0,
  96: 0,
  97: 0,
  98: 0,
  99: 0,
  100: 0,
  101: 0,
  102: 0,
  103: 0,
  104: 0,
  105: 1,
  106: 0,
  107: 0,
  108: 0,
  109: 0,
  110: 0,
  111: 0},
 'Pivot Price': {77: nan,
  78: nan,
  79: nan,
  80: nan,
  81: nan,
  82: nan,
  83: nan,
  84: nan,
  85: 1.2813942432403564,
  86: nan,
  87: nan,
  88: nan,
  89: 1.261781930923462,
  90: nan,
  91: nan,
  92: nan,
  93: nan,
  94: nan,
  95: nan,
  96: nan,
  97: nan,
  98: nan,
  99: nan,
  100: nan,
  101: nan,
  102: nan,
  103: nan,
  104: nan,
  105: 1.3210039138793943,
  106: nan,
  107: nan,
  108: nan,
  109: nan,
  110: nan,
  111: nan}}

Chart showing the problem

2019-01-03 should be the low pivot, not 2019-01-04.
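
A quick check on the Low values in dict1 supports this: after the 2018-12-31 peak (High 1.2813942432403564), the lowest Low before the rebound is on 2019-01-03, and it sits more than 1% below the peak. A minimal sketch in plain Python, using values copied from dict1:

```python
# Low values copied from dict1 for the bars after the 2018-12-31 peak.
lows = {
    '2018-12-31': 1.268182635307312,
    '2019-01-01': 1.2714558839797974,
    '2019-01-02': 1.2584631443023682,
    '2019-01-03': 1.2518777847290041,
    '2019-01-04': 1.261781930923462,
    '2019-01-07': 1.2724264860153198,
}
peak_high = 1.2813942432403564  # High on 2018-12-31

# The valley should be the bar with the lowest Low in this leg.
lowest_date = min(lows, key=lows.get)
drop = 1 - lows[lowest_date] / peak_high  # relative fall from the peak

print(lowest_date)    # 2019-01-03
print(f"{drop:.1%}")  # 2.3%
```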

Code that shows the problem on a chart:

import numpy as np
import plotly.graph_objects as go
import pandas as pd
from datetime import datetime

df = pd.DataFrame(dict1)

fig = go.Figure(data=[go.Candlestick(x=df['Date'],

                open=df['Open'],
                high=df['High'],
                low=df['Low'],
                close=df['Close'])])


df_diff = df['Pivot Price'].dropna().diff().copy()


fig.add_trace(
    go.Scatter(mode = "lines+markers",
        x=df['Date'],
        y=df["Pivot Price"]
    ))

fig.update_layout(
    autosize=False,
    width=1000,
    height=800,)

fig.add_trace(go.Scatter(x=df['Date'], y=df['Pivot Price'].interpolate(),
                         mode = 'lines',
                         line = dict(color='black')))


def annot(value):
    if np.isnan(value):
        return ''
    else:
        return value
    

j = 0
for i, p in enumerate(df['Pivot Price']):
    if not np.isnan(p):

        
        fig.add_annotation(dict(font=dict(color='rgba(0,0,200,0.8)',size=12),
                                        x=df['Date'].iloc[i],
                                        y=p,
                                        showarrow=False,
                                        text=annot(round(abs(df_diff.iloc[j]),3)),
                                        textangle=0,
                                        xanchor='right',
                                        xref="x",
                                        yref="y"))
        j = j + 1
fig.update_xaxes(type='category')
fig.show()

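Usually, the function works as shown in the chart below.

Edit: here is the code I used to create the 'Pivots' and 'Pivot Price' columns (updated based on @ands's comment):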

df['Pivots'] = pivots
df.loc[df['Pivots'] == 1, 'Pivot Price'] = df.High
df.loc[df['Pivots'] == -1, 'Pivot Price'] = df.Low

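Answer

There is a small problem with the Pivot Price column of df: your dataset for_so.csv already contains a Pivot Price column, so you need to clear the values in df['Pivot Price'] and then set them from the new pivots.

I used the following code to create the correct 'Pivots' and 'Pivot Price' columns: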

pivots = peak_valley_pivots_candlestick(df.Close, df.High, df.Low ,.01,-.01)
df['Pivots'] = pivots
df['Pivot Price'] = np.nan  # This line clears old pivot prices
df.loc[df['Pivots'] == 1, 'Pivot Price'] = df.High
df.loc[df['Pivots'] == -1, 'Pivot Price'] = df.Low
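
The np.nan line matters because df.loc[mask, 'Pivot Price'] = ... only writes the rows selected by the mask, so a row that held a pivot price from an earlier run keeps it. A minimal sketch with a hypothetical three-row frame, hypothetical pivots, and a hypothetical helper fill_pivot_prices that mirrors the lines above:

```python
import numpy as np
import pandas as pd

def fill_pivot_prices(df, pivots, clear_first):
    """Hypothetical helper mirroring the column code above."""
    out = df.copy()
    out['Pivots'] = pivots
    if clear_first:
        out['Pivot Price'] = np.nan  # drop stale values before refilling
    out.loc[out['Pivots'] == 1, 'Pivot Price'] = out.High
    out.loc[out['Pivots'] == -1, 'Pivot Price'] = out.Low
    return out

# Row 2 carries a stale pivot price from an earlier run.
df = pd.DataFrame({'High': [2.0, 3.0, 2.5],
                   'Low': [1.0, 1.5, 1.2],
                   'Pivot Price': [np.nan, np.nan, 9.9]})
new_pivots = np.array([-1, 1, 0])  # row 2 is no longer a pivot

stale = fill_pivot_prices(df, new_pivots, clear_first=False)
clean = fill_pivot_prices(df, new_pivots, clear_first=True)
print(stale['Pivot Price'].tolist())  # [1.0, 3.0, 9.9]
print(clean['Pivot Price'].tolist())  # [1.0, 3.0, nan]
```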

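The main problem is the zigzag code itself. The function peak_valley_pivots_candlestick has two small bugs. In the for loop, when the condition if r >= up_thresh: is true, last_pivot_x is set to x (a low), but it should be set to high[t], because after the reversal the loop tracks highs to find the next peak: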

if r >= up_thresh:
    pivots[last_pivot_t] = trend
    trend = 1
    #last_pivot_x = x
    last_pivot_x = high[t]
    last_pivot_t = t

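The same applies to the code in the condition if r <= down_thresh:, where last_pivot_x should be set to low[t] instead of x (a high), because after the reversal the loop tracks lows to find the next valley: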

if r <= down_thresh:
    pivots[last_pivot_t] = trend
    trend = -1
    #last_pivot_x = x
    last_pivot_x = low[t]
    last_pivot_t = t
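
To see the effect of the fix on the question's data, here is a simplified copy of the loop run on the High/Low values from dict1 for 2018-12-31 to 2019-01-14 with the same 1% thresholds. zigzag_core is a hypothetical helper, not the pull request's code: it starts from a given trend and uses high[0]/low[0] instead of close[0] and _identify_initial_pivot. The fixed flag switches between the two versions of last_pivot_x.

```python
def zigzag_core(high, low, up_thresh, down_thresh, start_trend, fixed=True):
    """Simplified copy of the main loop of peak_valley_pivots_candlestick.

    start_trend=1: bar 0 is the current peak candidate (tracking highs);
    start_trend=-1: bar 0 is the current valley candidate (tracking lows).
    """
    n = len(high)
    pivots = [0] * n
    trend = start_trend
    last_t = 0
    last_x = high[0] if trend == 1 else low[0]
    up, down = 1 + up_thresh, 1 + down_thresh
    for t in range(1, n):
        if trend == -1:                 # falling leg: track the lowest low
            x = low[t]
            if x / last_x >= up:        # rebound confirms the valley
                pivots[last_t] = trend
                trend = 1
                last_x = high[t] if fixed else x  # fixed: next peak candidate is a high
                last_t = t
            elif x < last_x:
                last_x, last_t = x, t
        else:                           # rising leg: track the highest high
            x = high[t]
            if x / last_x <= down:      # drop confirms the peak
                pivots[last_t] = trend
                trend = -1
                last_x = low[t] if fixed else x   # fixed: next valley candidate is a low
                last_t = t
            elif x > last_x:
                last_x, last_t = x, t
    # Same end-of-series handling as the original function.
    if last_t == n - 1:
        pivots[last_t] = trend
    elif pivots[n - 1] == 0:
        pivots[n - 1] = trend
    return pivots


# High/Low values from dict1, rows 85-95.
dates = ['2018-12-31', '2019-01-01', '2019-01-02', '2019-01-03', '2019-01-04',
         '2019-01-07', '2019-01-08', '2019-01-09', '2019-01-10', '2019-01-11',
         '2019-01-14']
high = [1.2813942432403564, 1.2756729125976562, 1.2773349285125732,
        1.2638230323791504, 1.2739664316177368, 1.2787723541259766,
        1.2792304754257202, 1.2802950143814087, 1.2801146507263184,
        1.2837464809417725, 1.292774677276611]
low = [1.268182635307312, 1.2714558839797974, 1.2584631443023682,
       1.2518777847290041, 1.261781930923462, 1.2724264860153198,
       1.2714881896972656, 1.271779179573059, 1.273058295249939,
       1.2716660499572754, 1.2821005582809448]

for fixed in (False, True):
    p = zigzag_core(high, low, .01, -.01, start_trend=1, fixed=fixed)
    print('fixed' if fixed else 'buggy', dates[p.index(-1)])
# buggy 2019-01-04
# fixed 2019-01-03
```

In the buggy version the reversal on 2019-01-03 stores that bar's High as the valley candidate, so the next bar's Low (2019-01-04) is "lower" and takes over the pivot, even though the 2019-01-03 Low is the real minimum.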

Here is the full code:

import numpy as np
import plotly.graph_objects as go
import pandas as pd


PEAK, VALLEY = 1, -1

def _identify_initial_pivot(X, up_thresh, down_thresh):
    """Quickly identify the X[0] as a peak or valley."""
    x_0 = X[0]
    max_x = x_0
    max_t = 0
    min_x = x_0
    min_t = 0
    up_thresh += 1
    down_thresh += 1

    for t in range(1, len(X)):
        x_t = X[t]

        if x_t / min_x >= up_thresh:
            return VALLEY if min_t == 0 else PEAK

        if x_t / max_x <= down_thresh:
            return PEAK if max_t == 0 else VALLEY

        if x_t > max_x:
            max_x = x_t
            max_t = t

        if x_t < min_x:
            min_x = x_t
            min_t = t

    t_n = len(X)-1
    return VALLEY if x_0 < X[t_n] else PEAK

def peak_valley_pivots_candlestick(close, high, low, up_thresh, down_thresh):
    """
    Finds the peaks and valleys of a series of HLC (open is not necessary).
    TR: This is modified peak_valley_pivots function in order to find peaks and valleys for OHLC.
    Parameters
    ----------
    close : Series of close prices.
    high : Series of high prices.
    low : Series of low prices.
    up_thresh : The minimum relative change necessary to define a peak.
    down_thresh : The minimum relative change necessary to define a valley (must be negative).
    Returns
    -------
    an array with 0 indicating no pivot and -1 and 1 indicating valley and peak
    respectively
    Using Pandas
    ------------
    For the most part, close, high and low may be a pandas series. However, the index must
    either be [0,n) or a DateTimeIndex. Why? This function does X[t] to access
    each element where t is in [0,n).
    The First and Last Elements
    ---------------------------
    The first and last elements are guaranteed to be annotated as peak or
    valley even if the segments formed do not have the necessary relative
    changes. This is a tradeoff between technical correctness and the
    propensity to make mistakes in data analysis. The possible mistake is
    ignoring data outside the fully realized segments, which may bias analysis.
    """
    if down_thresh > 0:
        raise ValueError('The down_thresh must be negative.')

    initial_pivot = _identify_initial_pivot(close, up_thresh, down_thresh)

    t_n = len(close)
    pivots = np.zeros(t_n, dtype='i1')
    pivots[0] = initial_pivot

    # Adding one to the relative change thresholds saves operations. Instead
    # of computing relative change at each point as x_j / x_i - 1, it is
    # computed as x_j / x_i. Then, this value is compared to the threshold + 1.
    # This saves (t_n - 1) subtractions.
    up_thresh += 1
    down_thresh += 1

    trend = -initial_pivot
    last_pivot_t = 0
    last_pivot_x = close[0]
    for t in range(1, len(close)):

        if trend == -1:
            x = low[t]
            r = x / last_pivot_x
            if r >= up_thresh:
                pivots[last_pivot_t] = trend
                trend = 1
                #last_pivot_x = x
                last_pivot_x = high[t]
                last_pivot_t = t
            elif x < last_pivot_x:
                last_pivot_x = x
                last_pivot_t = t
        else:
            x = high[t]
            r = x / last_pivot_x
            if r <= down_thresh:
                pivots[last_pivot_t] = trend
                trend = -1
                #last_pivot_x = x
                last_pivot_x = low[t]
                last_pivot_t = t
            elif x > last_pivot_x:
                last_pivot_x = x
                last_pivot_t = t


    if last_pivot_t == t_n-1:
        pivots[last_pivot_t] = trend
    elif pivots[t_n-1] == 0:
        pivots[t_n-1] = trend

    return pivots


df = pd.read_csv('for_so.csv')



pivots = peak_valley_pivots_candlestick(df.Close, df.High, df.Low ,.01,-.01)
df['Pivots'] = pivots
df['Pivot Price'] = np.nan  # This line clears old pivot prices
df.loc[df['Pivots'] == 1, 'Pivot Price'] = df.High
df.loc[df['Pivots'] == -1, 'Pivot Price'] = df.Low



fig = go.Figure(data=[go.Candlestick(x=df['Date'],

                open=df['Open'],
                high=df['High'],
                low=df['Low'],
                close=df['Close'])])


df_diff = df['Pivot Price'].dropna().diff().copy()


fig.add_trace(
    go.Scatter(mode = "lines+markers",
        x=df['Date'],
        y=df["Pivot Price"]
    ))

fig.update_layout(
    autosize=False,
    width=1000,
    height=800,)

fig.add_trace(go.Scatter(x=df['Date'],
                         y=df['Pivot Price'].interpolate(),
                         mode = 'lines',
                         line = dict(color='black')))


def annot(value):
    if np.isnan(value):
        return ''
    else:
        return value
    

j = 0
for i, p in enumerate(df['Pivot Price']):
    if not np.isnan(p):

        
        fig.add_annotation(dict(font=dict(color='rgba(0,0,200,0.8)',size=12),
                                        x=df['Date'].iloc[i],
                                        y=p,
                                        showarrow=False,
                                        text=annot(round(abs(df_diff.iloc[j]),3)),
                                        textangle=0,
                                        xanchor='right',
                                        xref="x",
                                        yref="y"))
        j = j + 1

        
fig.update_xaxes(type='category')
fig.show()

The code above produces this chart:

