実現したいこと
グラフ①に、グラフ②を重ねて表示したいが、グラフ①とグラフ②に分かれて出てしまう。
重なって出てこないため、いまは2つをそれぞれ fig.show() しています。
numpy, plotly, matpltlib でグラフをかいてます
前提
為替のグラフを作っていますが、本題はグラフをうまく表示できないことなので、為替のことは参考程度に記載します
↓
↓
↓
https://datapowernow.hatenablog.com/entry/2023/09/21/000737
http://home.netyou.jp/55/dpn/flag_pattern_detect_demo.html
https://www.youtube.com/watch?v=b5m7BZAHysk&list=LL&index=1
上記のコードを参考に、為替チャートのフラッグパターン(グラフ①)を作成、ボリンジャーバンド(グラフ②)を追加したいと思い、書きましたが、2つが重なって出てほしいです。
発生している問題・エラーメッセージ
/opt/homebrew/lib/python3.11/site-packages/_plotly_utils/basevalidators.py:105: FutureWarning: The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result
該当のソースコード
クソ長コードで申し訳ありませんが、概要を先に書きます
グラフ1の表示 → line 138
グラフ2の表示 → line 206
Python
1#グラフ1 2import numpy as np 3import pandas as pd 4import plotly.graph_objects as go 5from plotly.subplots import make_subplots 6from scipy.stats import linregress 7import matplotlib.pyplot as plt 8import yfinance as yf 9import pandas_ta as ta 10import matplotlib.dates as mdates 11from datetime import datetime, timedelta 12from pandas_datareader import data as web 13import mplfinance as mpf 14import datetime as dt 15 16def pivotid(df1, l, n1, n2): #n1 n2 before and after candle l 17 if l-n1 < 0 or l+n2 >= len(df1): 18 return 0 19 20 pividlow=1 21 pividhigh=1 22 for i in range(l-n1, l+n2+1): 23 if(df1.low[l]>df1.low[i]): 24 pividlow=0 25 if(df1.high[l]<df1.high[i]): 26 pividhigh=0 27 if pividlow and pividhigh: 28 return 3 29 elif pividlow: 30 return 1 31 elif pividhigh: 32 return 2 33 else: 34 return 0 35 36def detect_flag(df1, candle, backcandles, window, plot_flag=False): 37 """ 38 Attention! window should always be greater than the pivot window! to avoid look ahead bias 39 """ 40 localdf = df1[candle-backcandles-window:candle-window] 41 highs = localdf[localdf['pivot'] == 2].high.tail(3).values 42 idxhighs = localdf[localdf['pivot'] == 2].high.tail(3).index 43 lows = localdf[localdf['pivot'] == 1].low.tail(3).values 44 idxlows = localdf[localdf['pivot'] == 1].low.tail(3).index 45 46 if len(highs) == 3 and len(lows) == 3: 47 order_condition = ( 48 (idxlows[0] < idxhighs[0] 49 < idxlows[1] < idxhighs[1] 50 < idxlows[2] < idxhighs[2]) 51 or 52 (idxhighs[0] < idxlows[0] 53 < idxhighs[1] < idxlows[1] 54 < idxhighs[2] < idxlows[2]) ) 55 56 slmin, intercmin, rmin, _, _ = linregress(idxlows, lows) 57 slmax, intercmax, rmax, _, _ = linregress(idxhighs, highs) 58 59 if (order_condition 60 and (rmax*rmax)>=0.9 61 and (rmin*rmin)>=0.9 62 and slmin>=0.0001 63 and slmax<=-0.0001): 64 #and ((abs(slmin)-abs(slmax))/abs(slmax)) < 0.05): 65 66 if plot_flag: 67 fig = go.Figure(data=[go.Candlestick(x=localdf.index, 68 open=localdf['open'], 69 high=localdf['high'], 70 low=localdf['low'], 71 close=localdf['close'])]) 72 73 fig.add_scatter(x=localdf.index, y=localdf['pointpos'], mode="markers", 74 marker=dict(size=10, color="MediumPurple"), 75 name="pivot") 76 fig.add_trace(go.Scatter(x=idxlows, y=slmin*idxlows + intercmin, mode='lines', name='min slope')) 77 fig.add_trace(go.Scatter(x=idxhighs, y=slmax*idxhighs + intercmax, mode='lines', name='max slope')) 78 fig.update_layout( 79 xaxis_rangeslider_visible=False, 80 plot_bgcolor='white', # change the background to white 81 xaxis=dict(showgrid=True, gridcolor='white'), # change the x-axis grid to white 82 yaxis=dict(showgrid=True, gridcolor='white') # change the y-axis grid to white 83 ) 84 fig.show() 85 86 return 1 87 88 return 0 89 90 91yf.pdr_override() # yfinanceライブラリでpandas_datareaderの関数をオーバーライド 92start = dt.datetime.now() - dt.timedelta(hours=12) 93end = dt.datetime.now() 94symbol = 'USDJPY=X' 95df = web.get_data_yahoo(tickers=symbol,start=start,end=end,period='12h',interval='1m') 96df = df.sort_index() 97df = df.reset_index(names='Gmt time') 98df = df.drop(columns=['Adj Close', 'Volume']) 99df.columns=['time', 'open', 'high', 'low', 'close'] 100#Check if NA values are in data 101#df=df[df['volume']!=0] 102df.reset_index(drop=True, inplace=True) 103df.isna().sum() 104 105 106# 無名関数(lambda関数)を定義しています。この関数は、DataFrameの各行を入力として受け取り、その行の名前(インデックス)を x.name としてアクセスします。 107df['pivot'] = df.apply(lambda x: pivotid(df, x.name, 3, 3), axis=1) 108df.head() 109 110 111df['flag'] = df.apply(lambda x: detect_flag(df, x.name, 35, 3), axis=1) 112df.tail() 113 114 115df_detect = df[df['flag']!=0] 116df_detect 117 118 119def pointpos(x): 120 if x['pivot']==1: 121 return x['low']-1e-3 122 elif x['pivot']==2: 123 return x['high']+1e-3 124 else: 125 return np.nan 126 127df['pointpos'] = df.apply(lambda row: pointpos(row), axis=1) 128 129fig = go.Figure(data=[go.Candlestick(x=df.index, 130 open=df['open'], 131 high=df['high'], 132 low=df['low'], 133 close=df['close'])]) 134 135fig.add_scatter(x=df.index, y=df['pointpos'], mode="markers", 136 marker=dict(size=5, color="MediumPurple"), 137 name="pivot") 138fig.show() 139 140#グラフ2 141# 移動平均線 142df["SMA25"] = df["close"].rolling(window=25).mean() 143# 標準偏差 144df["std"] = df["close"].rolling(window=25).std() 145# ボリンジャーバンド 146df["2upper"] = df["SMA25"] + (2 * df["std"]) 147df["2lower"] = df["SMA25"] - (2 * df["std"]) 148df["3upper"] = df["SMA25"] + (3 * df["std"]) 149df["3lower"] = df["SMA25"] - (3 * df["std"]) 150# 非表示にする日付をリストアップ 151d_all = pd.date_range(start=df['time'].iloc[0],end=df['time'].iloc[-1]) 152d_obs = [d.strftime("%Y-%m-%d") for d in df['time']] 153d_breaks = [d for d in d_all.strftime("%Y-%m-%d").tolist() if not d in d_obs] 154 155# figを定義 156fig = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.05, row_width=[0.2, 0.7], x_title="Date") 157 158# Candlestick 159fig.add_trace( 160 go.Candlestick(x=df["time"], open=df["open"], high=df["high"], low=df["low"], close=df["close"], name="OHLC", showlegend=False), 161 row=1, col=1 162) 163 164# SMA 165fig.add_trace(go.Scatter(x=df["time"], y=df["SMA25"], name="SMA25", mode="lines"), row=1, col=1) 166 167# ボリンジャーバンド 168fig.add_trace( 169 go.Scatter(x=df["time"], y=df["2upper"], name="2σ", line=dict(width=1, color="red")), 170 row=1, col=1 171) 172fig.add_trace( 173 go.Scatter(x=df["time"], y=df["2lower"], line=dict(width=1, color="red"), showlegend=False), 174 row=1, col=1 175) 176 177fig.add_trace( 178 go.Scatter(x=df["time"], y=df["3upper"], name="3σ", line=dict(width=1, color="blue")), 179 row=1, col=1 180) 181fig.add_trace( 182 go.Scatter(x=df["time"], y=df["3lower"], line=dict(width=1, color="blue"), showlegend=False), 183 row=1, col=1 184) 185 186# Layout 187fig.update_layout( 188 title={ 189 "text": "日経平均の日足チャート", 190 "y":0.9, 191 "x":0.5, 192 } 193) 194 195# y軸名を定義 196fig.update_yaxes(title_text="株価", row=1, col=1) 197fig.update_yaxes(title_text="出来高", row=2, col=1) 198fig.update_yaxes(title_text="乖離率", row=3, col=1) 199 200# 不要な日付を非表示にする 201fig.update_xaxes( 202 rangebreaks=[dict(values=d_breaks)] 203) 204 205fig.update(layout_xaxis_rangeslider_visible=False) 206fig.show() 207
試したこと
line 156 のfig を新たに定義しているのがまずいかと思い、試行錯誤しました
定義を消すと、グラフ2はそれ以上表示されません。
/opt/homebrew/lib/python3.11/site-packages/_plotly_utils/basevalidators.py:105: FutureWarning: The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result Traceback (most recent call last): File "/opt/homebrew/lib/python3.11/site-packages/plotly/basedatatypes.py", line 2338, in _validate_get_grid_ref raise AttributeError("_grid_ref") AttributeError: _grid_ref During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/Users/uta/Documents/code/fx/bb.py", line 163, in <module> fig.add_trace( File "/opt/homebrew/lib/python3.11/site-packages/plotly/graph_objs/_figure.py", line 900, in add_trace return super(Figure, self).add_trace( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/opt/homebrew/lib/python3.11/site-packages/plotly/basedatatypes.py", line 2108, in add_trace return self.add_traces( ^^^^^^^^^^^^^^^^ File "/opt/homebrew/lib/python3.11/site-packages/plotly/graph_objs/_figure.py", line 980, in add_traces return super(Figure, self).add_traces( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/opt/homebrew/lib/python3.11/site-packages/plotly/basedatatypes.py", line 2238, in add_traces self._set_trace_grid_position(trace, row, col, secondary_y) File "/opt/homebrew/lib/python3.11/site-packages/plotly/basedatatypes.py", line 2329, in _set_trace_grid_position grid_ref = self._validate_get_grid_ref() ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/opt/homebrew/lib/python3.11/site-packages/plotly/basedatatypes.py", line 2340, in _validate_get_grid_ref raise Exception( Exception: In order to reference traces by row and column, you must first use plotly.tools.make_subplots to create the figure with a subplot grid.
補足情報(FW/ツールのバージョンなど)
macOS

回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2023/10/24 04:44
2023/10/24 15:33
2023/11/05 15:31