"""Monthly-rebalanced MA10 breakout backtest benchmarked against the CSI 300.

Bars are read from a local MongoDB ``stock_db`` database:

* ``daily``     -- daily bars for stocks (``index: False``) and indices
                   (``index: True``); trades and valuations use its
                   ``open``/``close`` fields and volumes track ``au_factor``.
* ``daily_hfq`` -- post-adjusted (hfq) daily bars, used only for the MA10
                   crossover signals.
"""
import random

import matplotlib.pyplot as plt
import pandas as pd
from pymongo import DESCENDING, MongoClient

# ========== Global database connection (edit as needed) ==========
DB_CONN = MongoClient('127.0.0.1', 27017)['stock_db']

# SimHei renders the Chinese plot labels; keep axis minus signs as ASCII
# hyphens (commonly needed when a CJK font is selected).
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

# Trading days per year used to annualise returns and volatility.
TRADING_DAYS_PER_YEAR = 245
# Annual risk-free rate used by the Sharpe ratio (4.75 %).
RISK_FREE_RATE = 4.75 / 100
# Index whose calendar defines trading days and whose return is the benchmark.
BENCHMARK_CODE = '000300'


# ========== Helper functions ==========
def get_trading_dates(start_date, end_date):
    """Return the sorted trading dates between ``start_date`` and ``end_date``.

    A trading date is any date on which the benchmark index has a bar in the
    ``daily`` collection. Both bounds are inclusive.
    """
    dates = DB_CONN['daily'].distinct('date', {
        'code': BENCHMARK_CODE,
        'index': True,
        'date': {'$gte': start_date, '$lte': end_date},
    })
    return sorted(dates)


def find_out_stocks(last_codes, curr_codes):
    """Return the codes in ``last_codes`` that are absent from ``curr_codes``."""
    return list(set(last_codes) - set(curr_codes))


def stock_pool(begin_date, end_date, pool_size=30, seed=42):
    """Build a placeholder stock pool, redrawn on each month's first trading day.

    Assumes dates are 'YYYY-MM-DD' strings: the month key is ``date[:7]``.

    Args:
        begin_date, end_date: inclusive bounds passed to ``get_trading_dates``.
        pool_size: maximum number of codes drawn per adjustment date.
        seed: seed of the private RNG, so runs are reproducible.

    Returns:
        ``(adjust_dates, date_codes_dict)``: the sorted adjustment dates and a
        mapping from each adjustment date to its randomly sampled stock codes.
    """
    # Dates arrive sorted, so the first date seen for a month is its earliest.
    month_first_day = {}
    for trade_date in get_trading_dates(begin_date, end_date):
        month_first_day.setdefault(trade_date[:7], trade_date)
    adjust_dates = sorted(month_first_day.values())

    all_stock_codes = DB_CONN['daily'].distinct('code', {'index': False})
    # A private RNG gives the same draws as seeding the module-level RNG, without
    # clobbering the process-wide ``random`` state.
    rng = random.Random(seed)
    sample_size = min(pool_size, len(all_stock_codes))
    date_codes_dict = {
        adj_date: rng.sample(all_stock_codes, sample_size)
        for adj_date in adjust_dates
    }
    return adjust_dates, date_codes_dict


# ========== Signal functions ==========
def compare_close_2_ma_10(dailies):
    """Compare the last close of a bar window with the window's moving average.

    Args:
        dailies: bars in ascending date order (callers pass exactly 10), each a
            dict with 'close' and 'is_trading'.

    Returns:
        1 if the last close is above the average, -1 if below, 0 if equal, or
        None if the window is empty or any bar lacks 'is_trading' or has it
        set to False.
    """
    if not dailies:
        return None
    close_sum = 0
    for daily in dailies:
        # A non-trading (or unflagged) bar makes the average meaningless.
        if 'is_trading' not in daily or daily['is_trading'] is False:
            return None
        close_sum += daily['close']
    moving_average = close_sum / len(dailies)
    last_close = dailies[-1]['close']
    if last_close > moving_average:
        return 1
    if last_close < moving_average:
        return -1
    return 0


def _ma10_states(code, _date, log_prefix):
    """Return (previous-day, current-day) close-vs-MA10 states for ``code``.

    Both states come from ``compare_close_2_ma_10`` on the hfq bars: the window
    of 10 bars ending the day before ``_date`` and the one ending on ``_date``.
    Returns None (after logging with ``log_prefix``) when ``_date`` has no
    trading bar or fewer than 11 bars exist up to it.
    """
    hfq = DB_CONN['daily_hfq']
    current_daily = hfq.find_one({'code': code, 'date': _date, 'is_trading': True})
    if current_daily is None:
        print('%s,当日没有K线,股票 %s,日期:%s' % (log_prefix, code, _date), flush=True)
        return None

    daily_cursor = hfq.find(
        {'code': code, 'date': {'$lte': _date}},
        limit=11,
        sort=[('date', DESCENDING)],
        projection={'code': True, 'close': True, 'is_trading': True},
    )
    # The query is newest-first; flip to chronological order.
    dailies = list(daily_cursor)[::-1]
    if len(dailies) < 11:
        print('%s,前期K线不足,股票 %s,日期:%s' % (log_prefix, code, _date), flush=True)
        return None

    return compare_close_2_ma_10(dailies[:10]), compare_close_2_ma_10(dailies[1:])


def is_k_up_break_ma10(code, _date):
    """Return True if the close crosses above MA10 on ``_date``.

    A cross means the previous day's close was at or below its MA10 and the
    current close is strictly above its MA10 (hfq prices).
    """
    states = _ma10_states(code, _date, '计算信号,K线上穿MA10')
    if states is None:
        return False
    last_state, current_state = states
    print('计算信号,K线上穿MA10,股票:%s,日期:%s, 前一日 %s,当日:%s' %
          (code, _date, str(last_state), str(current_state)), flush=True)
    if last_state is None or current_state is None:
        return False

    is_break = last_state <= 0 and current_state == 1
    print('计算信号,K线上穿MA10,股票:%s,日期:%s, 前一日 %s,当日:%s,突破:%s' %
          (code, _date, str(last_state), str(current_state), str(is_break)), flush=True)
    return is_break


def is_k_down_break_ma10(code, _date):
    """Return True if the close crosses below MA10 on ``_date``.

    A cross means the previous day's close was at or above its MA10 and the
    current close is strictly below its MA10 (hfq prices, like the up-cross).
    """
    states = _ma10_states(code, _date, '计算信号,K线下穿MA10')
    if states is None:
        return False
    last_state, current_state = states
    if last_state is None or current_state is None:
        return False

    is_break = last_state >= 0 and current_state == -1
    print('计算信号,K线下穿MA10,股票:%s,日期:%s, 前一日 %s,当日:%s, 突破:%s' %
          (code, _date, str(last_state), str(current_state), str(is_break)), flush=True)
    return is_break


# ========== Metric functions ==========
def compute_drawdown(net_values):
    """Return the maximum drawdown of a NAV series as a fraction (4 dp).

    The drawdown at each point is ``1 - nav / running_peak``. An empty series
    has no drawdown and yields 0.0.
    """
    if len(net_values) == 0:
        return 0.0
    nav = pd.Series(net_values, dtype=float)
    drawdowns = 1 - nav / nav.cummax()
    return round(float(max(drawdowns.max(), 0.0)), 4)


def compute_annual_profit(trading_days, end_nav):
    """Return the annualised return in percent (2 dp).

    Assumes the NAV started at 1.0 and grew to ``end_nav`` over
    ``trading_days`` days, with TRADING_DAYS_PER_YEAR days per year.
    """
    if trading_days <= 0:
        return 0.0
    if end_nav <= 0:
        # Total loss: a fractional power of a non-positive NAV is complex or
        # undefined, so report -100 % instead of crashing in round().
        return -100.0
    years = trading_days / TRADING_DAYS_PER_YEAR
    return round((end_nav ** (1 / years) - 1) * 100, 2)


def compute_sharpe_ratio(net_values):
    """Return ``(annualised_return_percent, sharpe_ratio)`` for a NAV series.

    Volatility is the daily-return standard deviation scaled by
    sqrt(TRADING_DAYS_PER_YEAR); the excess return is over RISK_FREE_RATE.
    The Sharpe ratio is 0.0 when volatility is zero or undefined (fewer than
    two daily returns).
    """
    if len(net_values) == 0:
        return 0.0, 0.0
    nav = pd.Series(net_values, dtype=float)
    ann_return = compute_annual_profit(len(nav), nav.iloc[-1])

    daily_ret = (nav / nav.shift(1) - 1).dropna()
    ann_std = daily_ret.std() * TRADING_DAYS_PER_YEAR ** 0.5
    if pd.isna(ann_std) or ann_std == 0:
        return ann_return, 0.0
    sharpe = (ann_return / 100 - RISK_FREE_RATE) / ann_std
    return ann_return, round(float(sharpe), 2)


# ========== Backtest ==========
def _adjust_volumes_for_au_factor(holding_code_dict, last_date, _date):
    """Rescale held volumes by each code's au_factor change from ``last_date`` to ``_date``.

    This is the ex-rights / ex-dividend holding adjustment. Codes missing a
    ``daily`` bar on either date keep their volume unchanged.
    """
    codes = list(holding_code_dict)
    daily = DB_CONN['daily']
    last_cursor = daily.find(
        {'code': {'$in': codes}, 'date': last_date, 'index': False},
        projection={'code': True, 'au_factor': True})
    last_aufactor = {d['code']: d['au_factor'] for d in last_cursor}

    current_cursor = daily.find(
        {'code': {'$in': codes}, 'date': _date, 'index': False},
        projection={'code': True, 'au_factor': True})
    for curr in current_cursor:
        code = curr['code']
        if code not in last_aufactor:
            continue
        curr_auf = curr['au_factor']
        last_auf = last_aufactor[code]
        last_vol = holding_code_dict[code]['volume']
        new_vol = int(last_vol * (curr_auf / last_auf))
        holding_code_dict[code]['volume'] = new_vol
        print('持仓量调整:%s, %6d, %10.6f, %6d, %10.6f' %
              (code, last_vol, last_auf, new_vol, curr_auf))


def _sell_at_open(_date, to_be_sold_codes, holding_code_dict, cash):
    """Sell pending codes that are held and trading on ``_date`` at the open.

    Sold codes are removed from ``holding_code_dict`` and ``to_be_sold_codes``.
    Codes with no ``is_trading`` bar on ``_date`` stay pending for a later day.
    Returns the updated cash balance.
    """
    sell_cursor = DB_CONN['daily'].find(
        {'code': {'$in': list(to_be_sold_codes)}, 'date': _date,
         'index': False, 'is_trading': True},
        projection={'open': True, 'code': True})
    for bar in sell_cursor:
        code = bar['code']
        hold = holding_code_dict.get(code)
        if hold is None:
            continue
        vol = hold['volume']
        sell_price = bar['open']
        sell_amt = vol * sell_price
        cash += sell_amt
        cost = hold['cost']
        pct = (sell_amt - cost) / cost * 100
        print('卖出 %s, %6d, %6.2f, %8.2f, %4.2f' % (code, vol, sell_price, sell_amt, pct))
        del holding_code_dict[code]
        to_be_sold_codes.discard(code)
    return cash


def _buy_at_open(_date, to_be_bought_codes, holding_code_dict, cash, single_position):
    """Buy pending codes trading on ``_date`` at the open, one position each.

    Each buy spends at most ``single_position``, rounded down to 100-share
    lots, while cash exceeds ``single_position``. Returns the updated cash.
    """
    buy_cursor = DB_CONN['daily'].find(
        {'code': {'$in': list(to_be_bought_codes)}, 'date': _date,
         'is_trading': True, 'index': False},
        projection={'code': True, 'open': True})
    for bar in buy_cursor:
        # Cash only decreases inside this loop, so once it is insufficient no
        # later candidate can be bought either.
        if cash <= single_position:
            break
        buy_price = bar['open']
        code = bar['code']
        vol = int(single_position / buy_price) // 100 * 100
        if vol == 0:
            # One lot costs more than a position; recording a zero-cost holding
            # would divide by zero when its profit is computed.
            continue
        buy_amt = buy_price * vol
        cash -= buy_amt
        holding_code_dict[code] = {'volume': vol, 'cost': buy_amt, 'last_value': buy_amt}
        print('买入 %s, %6d, %6.2f, %8.2f' % (code, vol, buy_price, buy_amt))
    return cash


def _mark_to_market(_date, holding_code_dict):
    """Value all holdings at the ``_date`` close and return their total.

    Updates each holding's 'last_value'. A holding with no ``daily`` bar on
    ``_date`` is carried at its previous 'last_value' rather than dropping
    out of total assets.
    """
    if not holding_code_dict:
        return 0.0
    total_mv = 0.0
    valued = set()
    cursor = DB_CONN['daily'].find(
        {'code': {'$in': list(holding_code_dict)}, 'date': _date, 'index': False},
        projection={'close': True, 'code': True})
    for bar in cursor:
        code = bar['code']
        hold = holding_code_dict[code]
        mv = bar['close'] * hold['volume']
        total_mv += mv
        profit_all = (mv - hold['cost']) / hold['cost'] * 100
        day_p = (mv - hold['last_value']) / hold['last_value'] * 100
        hold['last_value'] = mv
        valued.add(code)
        print('持仓: %s, %10.2f, %4.2f, %4.2f' % (code, mv, profit_all, day_p))
    total_mv += sum(hold['last_value']
                    for code, hold in holding_code_dict.items() if code not in valued)
    return total_mv


def _plot_results(df_profit):
    """Plot strategy vs benchmark NAV (top) and cumulative return % (bottom)."""
    _, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))

    # Panel 1: net asset value
    ax1.plot(df_profit.index, df_profit['net_value'], label='策略净值', color='red')
    ax1.plot(df_profit.index, df_profit['hs300_nav'], label='沪深300净值', color='blue')
    ax1.set_title('策略净值 vs 沪深300净值')
    ax1.legend()
    ax1.grid()

    # Panel 2: cumulative return in percent
    ax2.plot(df_profit.index, df_profit['profit'], label='策略累计收益%', color='red')
    ax2.plot(df_profit.index, df_profit['hs300'], label='沪深300累计收益%', color='blue')
    ax2.set_title('累计收益对比(%)')
    ax2.legend()
    ax2.grid()

    plt.tight_layout()
    plt.show()


def backtest(begin_date, end_date, initial_cash=1E7, single_position=2E5):
    """Run the MA10 breakout strategy from ``begin_date`` to ``end_date``.

    Each trading day:
    1. Adjust held volumes for au_factor changes.
    2. Sell pending codes at the open, then buy pending codes at the open.
    3. On pool adjustment dates, mark held codes that left the pool for sale.
    4. Generate next-day sell (down-cross) and buy (up-cross) signals.
    5. Value the portfolio at the close.

    Args:
        begin_date, end_date: inclusive bounds, in the stored date format.
        initial_cash: starting cash; also the NAV denominator.
        single_position: cash budget for each new position.

    Returns:
        ``(df_profit, ann_ret, max_dd, sharpe)``, or None when the range has no
        trading dates. ``df_profit`` is indexed by date with columns
        ``net_value``, ``profit`` (%), ``hs300`` (%) and ``hs300_nav``.
    """
    all_dates = get_trading_dates(begin_date, end_date)
    if not all_dates:
        print('无交易日数据')
        return None

    # Benchmark close on the first day: the base of its cumulative return.
    hs300_begin_value = DB_CONN['daily'].find_one(
        {'code': BENCHMARK_CODE, 'index': True, 'date': all_dates[0]},
        projection={'close': True})['close']

    adjust_dates, date_codes_dict = stock_pool(begin_date, end_date)
    adjust_date_set = set(adjust_dates)

    cash = initial_cash
    last_phase_codes = None
    this_phase_codes = None
    to_be_sold_codes = set()
    to_be_bought_codes = set()
    # code -> {'volume': shares, 'cost': purchase amount, 'last_value': last valuation}
    holding_code_dict = {}
    last_date = None
    record_dates = []
    records = []

    for _date in all_dates:
        print('Backtest at %s.' % _date)

        if last_date is not None and holding_code_dict:
            _adjust_volumes_for_au_factor(holding_code_dict, last_date, _date)

        print('待卖股票池:', to_be_sold_codes, flush=True)
        if to_be_sold_codes:
            cash = _sell_at_open(_date, to_be_sold_codes, holding_code_dict, cash)
        print('卖出后,现金: %10.2f' % cash)

        print('待买股票池:', to_be_bought_codes, flush=True)
        if to_be_bought_codes:
            cash = _buy_at_open(_date, to_be_bought_codes, holding_code_dict,
                                cash, single_position)
        print('买入后,现金: %10.2f' % cash)

        holding_codes = list(holding_code_dict)

        # Pool rebalance: held codes that left the pool are sold next open.
        if _date in adjust_date_set:
            print('股票池调整日:%s,备选股票列表:' % _date, flush=True)
            if this_phase_codes is not None:
                last_phase_codes = this_phase_codes
            this_phase_codes = date_codes_dict[_date]
            print(this_phase_codes, flush=True)
            if last_phase_codes is not None:
                for out_code in find_out_stocks(last_phase_codes, this_phase_codes):
                    if out_code in holding_code_dict:
                        to_be_sold_codes.add(out_code)

        # Signals computed on today's close are executed at the next open.
        for held_code in holding_codes:
            if is_k_down_break_ma10(held_code, _date):
                to_be_sold_codes.add(held_code)
        to_be_bought_codes.clear()
        if this_phase_codes is not None:
            for code in this_phase_codes:
                if code not in holding_code_dict and is_k_up_break_ma10(code, _date):
                    to_be_bought_codes.add(code)

        total_cap = _mark_to_market(_date, holding_code_dict) + cash

        hs300_curr = DB_CONN['daily'].find_one(
            {'code': BENCHMARK_CODE, 'index': True, 'date': _date},
            projection={'close': True})['close']
        print('收盘后,现金: %10.2f, 总资产: %10.2f' % (cash, total_cap))

        last_date = _date
        record_dates.append(_date)
        records.append({
            'net_value': round(total_cap / initial_cash, 4),
            'profit': round((total_cap - initial_cash) / initial_cash * 100, 2),
            'hs300': round((hs300_curr - hs300_begin_value) / hs300_begin_value * 100, 2),
        })

    # Build the frame once, with float columns, instead of enlarging row by row.
    df_profit = pd.DataFrame(records, index=record_dates,
                             columns=['net_value', 'profit', 'hs300'])
    # Benchmark NAV on the same 1.0 base as the strategy.
    df_profit['hs300_nav'] = 1 + df_profit['hs300'] / 100

    nav_series = df_profit['net_value']
    max_dd = compute_drawdown(nav_series)
    ann_ret, sharpe = compute_sharpe_ratio(nav_series)

    print('\n======回测结果汇总======')
    print(f'回测区间:{begin_date} ~ {end_date}')
    print(f'年化收益率:{ann_ret:.2f} %')
    print(f'最大回撤:{max_dd*100:.2f} %')
    print(f'夏普比率:{sharpe:.2f}')

    _plot_results(df_profit)
    return df_profit, ann_ret, max_dd, sharpe


if __name__ == "__main__":
    result = backtest('2015-01-01', '2015-12-31')
    # backtest() returns None when the range has no trading dates.
    if result is not None:
        df_result, ann_rate, maxdd, shar = result