From abfc80c8b922e465604bc917dd00c1eb02e1c2c2 Mon Sep 17 00:00:00 2001 From: jibi Date: Sun, 28 Nov 2021 15:44:50 -0800 Subject: [PATCH 1/7] update quandl download after NASDAQ acquisition. --- pyalgotrade/tools/quandl.py | 72 +++++++++++++++++++++++++------------ 1 file changed, 49 insertions(+), 23 deletions(-) diff --git a/pyalgotrade/tools/quandl.py b/pyalgotrade/tools/quandl.py index 5014a2b70..90f845394 100644 --- a/pyalgotrade/tools/quandl.py +++ b/pyalgotrade/tools/quandl.py @@ -31,10 +31,15 @@ import pyalgotrade.logger -# http://www.quandl.com/help/api +# https://docs.data.nasdaq.com/docs/in-depth-usage + +base_url = 'https://data.nasdaq.com/api/v3/datasets' + def download_csv(sourceCode, tableCode, begin, end, frequency, authToken): - url = "http://www.quandl.com/api/v1/datasets/%s/%s.csv" % (sourceCode, tableCode) + + url = f'{base_url}/{sourceCode}/{tableCode}/data.csv' + params = { "trim_start": begin.strftime("%Y-%m-%d"), "trim_end": end.strftime("%Y-%m-%d"), @@ -61,7 +66,8 @@ def download_daily_bars(sourceCode, tableCode, year, csvFile, authToken=None): :type authToken: string. """ - bars = download_csv(sourceCode, tableCode, datetime.date(year, 1, 1), datetime.date(year, 12, 31), "daily", authToken) + bars = download_csv(sourceCode, tableCode, datetime.date( + year, 1, 1), datetime.date(year, 12, 31), "daily", authToken) f = open(csvFile, "w") f.write(bars) f.close() @@ -82,8 +88,10 @@ def download_weekly_bars(sourceCode, tableCode, year, csvFile, authToken=None): :type authToken: string. """ - begin = dt.get_first_monday(year) - datetime.timedelta(days=1) # Start on a sunday - end = dt.get_last_monday(year) - datetime.timedelta(days=1) # Start on a sunday + begin = dt.get_first_monday( + year) - datetime.timedelta(days=1) # Start on a sunday + end = dt.get_last_monday( + year) - datetime.timedelta(days=1) # Start on a sunday bars = download_csv(sourceCode, tableCode, begin, end, "weekly", authToken) f = open(csvFile, "w") f.write(bars) @@ -145,37 +153,51 @@ def build_feed(sourceCode, tableCodes, fromYear, toYear, storage, frequency=bar. for year in range(fromYear, toYear+1): for tableCode in tableCodes: - fileName = os.path.join(storage, "%s-%s-%d-quandl.csv" % (sourceCode, tableCode, year)) + fileName = os.path.join( + storage, "%s-%s-%d-quandl.csv" % (sourceCode, tableCode, year)) if not os.path.exists(fileName) or forceDownload: - logger.info("Downloading %s %d to %s" % (tableCode, year, fileName)) + logger.info("Downloading %s %d to %s" % + (tableCode, year, fileName)) try: if frequency == bar.Frequency.DAY: - download_daily_bars(sourceCode, tableCode, year, fileName, authToken) + download_daily_bars( + sourceCode, tableCode, year, fileName, authToken) else: assert frequency == bar.Frequency.WEEK, "Invalid frequency" - download_weekly_bars(sourceCode, tableCode, year, fileName, authToken) + download_weekly_bars( + sourceCode, tableCode, year, fileName, authToken) except Exception as e: if skipErrors: logger.error(str(e)) continue else: raise e - ret.addBarsFromCSV(tableCode, fileName, skipMalformedBars=skipMalformedBars) + ret.addBarsFromCSV(tableCode, fileName, + skipMalformedBars=skipMalformedBars) return ret def main(): parser = argparse.ArgumentParser(description="Quandl utility") - parser.add_argument("--auth-token", required=False, help="An authentication token needed if you're doing more than 50 calls per day") - parser.add_argument("--source-code", required=True, help="The dataset source code") - parser.add_argument("--table-code", required=True, help="The dataset table code") - parser.add_argument("--from-year", required=True, type=int, help="The first year to download") - parser.add_argument("--to-year", required=True, type=int, help="The last year to download") - parser.add_argument("--storage", required=True, help="The path were the files will be downloaded to") - parser.add_argument("--force-download", action='store_true', help="Force downloading even if the files exist") - parser.add_argument("--ignore-errors", action='store_true', help="True to keep on downloading files in case of errors") - parser.add_argument("--frequency", default="daily", choices=["daily", "weekly"], help="The frequency of the bars. Only daily or weekly are supported") + parser.add_argument("--auth-token", required=False, + help="An authentication token needed if you're doing more than 50 calls per day") + parser.add_argument("--source-code", required=True, + help="The dataset source code") + parser.add_argument("--table-code", required=True, + help="The dataset table code") + parser.add_argument("--from-year", required=True, + type=int, help="The first year to download") + parser.add_argument("--to-year", required=True, type=int, + help="The last year to download") + parser.add_argument("--storage", required=True, + help="The path were the files will be downloaded to") + parser.add_argument("--force-download", action='store_true', + help="Force downloading even if the files exist") + parser.add_argument("--ignore-errors", action='store_true', + help="True to keep on downloading files in case of errors") + parser.add_argument("--frequency", default="daily", choices=[ + "daily", "weekly"], help="The frequency of the bars. Only daily or weekly are supported") args = parser.parse_args() @@ -186,15 +208,19 @@ def main(): os.mkdir(args.storage) for year in range(args.from_year, args.to_year+1): - fileName = os.path.join(args.storage, "%s-%s-%d-quandl.csv" % (args.source_code, args.table_code, year)) + fileName = os.path.join(args.storage, "%s-%s-%d-quandl.csv" % + (args.source_code, args.table_code, year)) if not os.path.exists(fileName) or args.force_download: - logger.info("Downloading %s %d to %s" % (args.table_code, year, fileName)) + logger.info("Downloading %s %d to %s" % + (args.table_code, year, fileName)) try: if args.frequency == "daily": - download_daily_bars(args.source_code, args.table_code, year, fileName, args.auth_token) + download_daily_bars( + args.source_code, args.table_code, year, fileName, args.auth_token) else: assert args.frequency == "weekly", "Invalid frequency" - download_weekly_bars(args.source_code, args.table_code, year, fileName, args.auth_token) + download_weekly_bars( + args.source_code, args.table_code, year, fileName, args.auth_token) except Exception as e: if args.ignore_errors: logger.error(str(e)) From bc2428f8858b4641dff2789165230305ebda8e58 Mon Sep 17 00:00:00 2001 From: jibi Date: Sat, 4 Dec 2021 11:59:52 -0800 Subject: [PATCH 2/7] initial commit: copying bitstamp files as base for alpaca. --- pyalgotrade/alpaca/__init__.py | 19 ++ pyalgotrade/alpaca/barfeed.py | 25 +++ pyalgotrade/alpaca/broker.py | 123 ++++++++++++ pyalgotrade/alpaca/common.py | 31 +++ pyalgotrade/alpaca/httpclient.py | 237 ++++++++++++++++++++++ pyalgotrade/alpaca/livebroker.py | 332 +++++++++++++++++++++++++++++++ pyalgotrade/alpaca/livefeed.py | 267 +++++++++++++++++++++++++ pyalgotrade/alpaca/wsclient.py | 198 ++++++++++++++++++ 8 files changed, 1232 insertions(+) create mode 100644 pyalgotrade/alpaca/__init__.py create mode 100644 pyalgotrade/alpaca/barfeed.py create mode 100644 pyalgotrade/alpaca/broker.py create mode 100644 pyalgotrade/alpaca/common.py create mode 100644 pyalgotrade/alpaca/httpclient.py create mode 100644 pyalgotrade/alpaca/livebroker.py create mode 100644 pyalgotrade/alpaca/livefeed.py create mode 100644 pyalgotrade/alpaca/wsclient.py diff --git a/pyalgotrade/alpaca/__init__.py b/pyalgotrade/alpaca/__init__.py new file mode 100644 index 000000000..346270fc8 --- /dev/null +++ b/pyalgotrade/alpaca/__init__.py @@ -0,0 +1,19 @@ +# PyAlgoTrade +# +# Copyright 2011-2018 Gabriel Martin Becedillas Ruiz +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +.. moduleauthor:: Gabriel Martin Becedillas Ruiz +""" diff --git a/pyalgotrade/alpaca/barfeed.py b/pyalgotrade/alpaca/barfeed.py new file mode 100644 index 000000000..4abd536b3 --- /dev/null +++ b/pyalgotrade/alpaca/barfeed.py @@ -0,0 +1,25 @@ +# PyAlgoTrade +# +# Copyright 2011-2018 Gabriel Martin Becedillas Ruiz +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +.. moduleauthor:: Gabriel Martin Becedillas Ruiz +""" + + +from pyalgotrade.bitstamp import livefeed + + +LiveTradeFeed = livefeed.LiveTradeFeed diff --git a/pyalgotrade/alpaca/broker.py b/pyalgotrade/alpaca/broker.py new file mode 100644 index 000000000..1a1717aab --- /dev/null +++ b/pyalgotrade/alpaca/broker.py @@ -0,0 +1,123 @@ +# PyAlgoTrade +# +# Copyright 2011-2018 Gabriel Martin Becedillas Ruiz +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +.. moduleauthor:: Gabriel Martin Becedillas Ruiz +""" + + +from pyalgotrade import broker +from pyalgotrade.broker import backtesting +from pyalgotrade.bitstamp import common +from pyalgotrade.bitstamp import livebroker + + +LiveBroker = livebroker.LiveBroker + +# In a backtesting or paper-trading scenario the BacktestingBroker dispatches events while processing events from the +# BarFeed. +# It is guaranteed to process BarFeed events before the strategy because it connects to BarFeed events before the +# strategy. + + +class BacktestingBroker(backtesting.Broker): + MIN_TRADE_USD = 5 + + """A Bitstamp backtesting broker. + + :param cash: The initial amount of cash. + :type cash: int/float. + :param barFeed: The bar feed that will provide the bars. + :type barFeed: :class:`pyalgotrade.barfeed.BarFeed` + :param fee: The fee percentage for each order. Defaults to 0.25%. + :type fee: float. + + .. note:: + * Only limit orders are supported. + * Orders are automatically set as **goodTillCanceled=True** and **allOrNone=False**. + * BUY_TO_COVER orders are mapped to BUY orders. + * SELL_SHORT orders are mapped to SELL orders. + """ + + def __init__(self, cash, barFeed, fee=0.0025): + commission = backtesting.TradePercentage(fee) + super(BacktestingBroker, self).__init__(cash, barFeed, commission) + + def getInstrumentTraits(self, instrument): + return common.BTCTraits() + + def submitOrder(self, order): + if order.isInitial(): + # Override user settings based on Bitstamp limitations. + order.setAllOrNone(False) + order.setGoodTillCanceled(True) + return super(BacktestingBroker, self).submitOrder(order) + + def createMarketOrder(self, action, instrument, quantity, onClose=False): + raise Exception("Market orders are not supported") + + def createLimitOrder(self, action, instrument, limitPrice, quantity): + if instrument != common.btc_symbol: + raise Exception("Only BTC instrument is supported") + + if action == broker.Order.Action.BUY_TO_COVER: + action = broker.Order.Action.BUY + elif action == broker.Order.Action.SELL_SHORT: + action = broker.Order.Action.SELL + + if limitPrice * quantity < BacktestingBroker.MIN_TRADE_USD: + raise Exception("Trade must be >= %s" % (BacktestingBroker.MIN_TRADE_USD)) + + if action == broker.Order.Action.BUY: + # Check that there is enough cash. + fee = self.getCommission().calculate(None, limitPrice, quantity) + cashRequired = limitPrice * quantity + fee + if cashRequired > self.getCash(False): + raise Exception("Not enough cash") + elif action == broker.Order.Action.SELL: + # Check that there are enough coins. + if quantity > self.getShares(common.btc_symbol): + raise Exception("Not enough %s" % (common.btc_symbol)) + else: + raise Exception("Only BUY/SELL orders are supported") + + return super(BacktestingBroker, self).createLimitOrder(action, instrument, limitPrice, quantity) + + def createStopOrder(self, action, instrument, stopPrice, quantity): + raise Exception("Stop orders are not supported") + + def createStopLimitOrder(self, action, instrument, stopPrice, limitPrice, quantity): + raise Exception("Stop limit orders are not supported") + + +class PaperTradingBroker(BacktestingBroker): + """A Bitstamp paper trading broker. + + :param cash: The initial amount of cash. + :type cash: int/float. + :param barFeed: The bar feed that will provide the bars. + :type barFeed: :class:`pyalgotrade.barfeed.BarFeed` + :param fee: The fee percentage for each order. Defaults to 0.5%. + :type fee: float. + + .. note:: + * Only limit orders are supported. + * Orders are automatically set as **goodTillCanceled=True** and **allOrNone=False**. + * BUY_TO_COVER orders are mapped to BUY orders. + * SELL_SHORT orders are mapped to SELL orders. + """ + + pass diff --git a/pyalgotrade/alpaca/common.py b/pyalgotrade/alpaca/common.py new file mode 100644 index 000000000..b42b87bf5 --- /dev/null +++ b/pyalgotrade/alpaca/common.py @@ -0,0 +1,31 @@ +# PyAlgoTrade +# +# Copyright 2011-2018 Gabriel Martin Becedillas Ruiz +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +.. moduleauthor:: Gabriel Martin Becedillas Ruiz +""" + +import pyalgotrade.logger +from pyalgotrade import broker + + +logger = pyalgotrade.logger.getLogger("bitstamp") +btc_symbol = "BTC" + + +class BTCTraits(broker.InstrumentTraits): + def roundQuantity(self, quantity): + return round(quantity, 8) diff --git a/pyalgotrade/alpaca/httpclient.py b/pyalgotrade/alpaca/httpclient.py new file mode 100644 index 000000000..e51f4e8da --- /dev/null +++ b/pyalgotrade/alpaca/httpclient.py @@ -0,0 +1,237 @@ +# PyAlgoTrade +# +# Copyright 2011-2018 Gabriel Martin Becedillas Ruiz +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +.. moduleauthor:: Gabriel Martin Becedillas Ruiz +""" + +import time +import datetime +import hmac +import hashlib +import requests +import threading + +from pyalgotrade.utils import dt +from pyalgotrade.bitstamp import common + +import logging +logging.getLogger("requests").setLevel(logging.ERROR) + + +def parse_datetime(dateTime): + try: + ret = datetime.datetime.strptime(dateTime, "%Y-%m-%d %H:%M:%S") + except ValueError: + ret = datetime.datetime.strptime(dateTime, "%Y-%m-%d %H:%M:%S.%f") + return dt.as_utc(ret) + + +class NonceGenerator(object): + def __init__(self): + self.__prev = None + + def getNext(self): + ret = int(time.time()) + if self.__prev is not None and ret <= self.__prev: + ret = self.__prev + 1 + self.__prev = ret + return ret + + +class AccountBalance(object): + def __init__(self, jsonDict): + self.__jsonDict = jsonDict + + def getDict(self): + return self.__jsonDict + + def getUSDAvailable(self): + return float(self.__jsonDict["usd_available"]) + + def getBTCAvailable(self): + return float(self.__jsonDict["btc_available"]) + + +class Order(object): + def __init__(self, jsonDict): + self.__jsonDict = jsonDict + + def getDict(self): + return self.__jsonDict + + def getId(self): + return int(self.__jsonDict["id"]) + + def isBuy(self): + return self.__jsonDict["type"] == 0 + + def isSell(self): + return self.__jsonDict["type"] == 1 + + def getPrice(self): + return float(self.__jsonDict["price"]) + + def getAmount(self): + return float(self.__jsonDict["amount"]) + + def getDateTime(self): + return parse_datetime(self.__jsonDict["datetime"]) + + +class UserTransaction(object): + def __init__(self, jsonDict): + self.__jsonDict = jsonDict + + def getDict(self): + return self.__jsonDict + + def getBTC(self): + return float(self.__jsonDict["btc"]) + + def getBTCUSD(self): + return float(self.__jsonDict["btc_usd"]) + + def getDateTime(self): + return parse_datetime(self.__jsonDict["datetime"]) + + def getFee(self): + return float(self.__jsonDict["fee"]) + + def getId(self): + return int(self.__jsonDict["id"]) + + def getOrderId(self): + return int(self.__jsonDict["order_id"]) + + def getUSD(self): + return float(self.__jsonDict["usd"]) + + +class HTTPClient(object): + USER_AGENT = "PyAlgoTrade" + REQUEST_TIMEOUT = 30 + + class UserTransactionType: + MARKET_TRADE = 2 + + def __init__(self, clientId, key, secret): + self.__clientId = clientId + self.__key = key + self.__secret = secret + self.__nonce = NonceGenerator() + self.__lock = threading.Lock() + + def _buildQuery(self, params): + # Build the signature. + nonce = self.__nonce.getNext() + message = "%d%s%s" % (nonce, self.__clientId, self.__key) + signature = hmac.new(self.__secret, msg=message, digestmod=hashlib.sha256).hexdigest().upper() + + # Headers + headers = {} + headers["User-Agent"] = HTTPClient.USER_AGENT + + # POST data. + data = {} + data.update(params) + data["key"] = self.__key + data["signature"] = signature + data["nonce"] = nonce + + return (data, headers) + + def _post(self, url, params): + common.logger.debug("POST to %s with params %s" % (url, str(params))) + + # Serialize access to nonce generation and http requests to avoid + # sending them in the wrong order. + with self.__lock: + data, headers = self._buildQuery(params) + response = requests.post(url, headers=headers, data=data, timeout=HTTPClient.REQUEST_TIMEOUT) + response.raise_for_status() + + jsonResponse = response.json() + + # Check for errors. + if isinstance(jsonResponse, dict): + error = jsonResponse.get("error") + if error is not None: + raise Exception(error) + + return jsonResponse + + def getAccountBalance(self): + url = "https://www.bitstamp.net/api/balance/" + jsonResponse = self._post(url, {}) + return AccountBalance(jsonResponse) + + def getOpenOrders(self): + url = "https://www.bitstamp.net/api/open_orders/" + jsonResponse = self._post(url, {}) + return [Order(json_open_order) for json_open_order in jsonResponse] + + def cancelOrder(self, orderId): + url = "https://www.bitstamp.net/api/cancel_order/" + params = {"id": orderId} + jsonResponse = self._post(url, params) + if jsonResponse != True: + raise Exception("Failed to cancel order") + + def buyLimit(self, limitPrice, quantity): + url = "https://www.bitstamp.net/api/buy/" + + # Rounding price to avoid 'Ensure that there are no more than 2 decimal places' + # error. + price = round(limitPrice, 2) + # Rounding amount to avoid 'Ensure that there are no more than 8 decimal places' + # error. + amount = round(quantity, 8) + + params = { + "price": price, + "amount": amount + } + jsonResponse = self._post(url, params) + return Order(jsonResponse) + + def sellLimit(self, limitPrice, quantity): + url = "https://www.bitstamp.net/api/sell/" + + # Rounding price to avoid 'Ensure that there are no more than 2 decimal places' + # error. + price = round(limitPrice, 2) + # Rounding amount to avoid 'Ensure that there are no more than 8 decimal places' + # error. + amount = round(quantity, 8) + + params = { + "price": price, + "amount": amount + } + jsonResponse = self._post(url, params) + return Order(jsonResponse) + + def getUserTransactions(self, transactionType=None): + url = "https://www.bitstamp.net/api/user_transactions/" + jsonResponse = self._post(url, {}) + if transactionType is not None: + jsonUserTransactions = filter( + lambda jsonUserTransaction: jsonUserTransaction["type"] == transactionType, jsonResponse + ) + else: + jsonUserTransactions = jsonResponse + return [UserTransaction(jsonUserTransaction) for jsonUserTransaction in jsonUserTransactions] diff --git a/pyalgotrade/alpaca/livebroker.py b/pyalgotrade/alpaca/livebroker.py new file mode 100644 index 000000000..6bb8edf7e --- /dev/null +++ b/pyalgotrade/alpaca/livebroker.py @@ -0,0 +1,332 @@ +# PyAlgoTrade +# +# Copyright 2011-2018 Gabriel Martin Becedillas Ruiz +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +.. moduleauthor:: Gabriel Martin Becedillas Ruiz +""" + +import threading +import time + +from six.moves import queue + +from pyalgotrade import broker +from pyalgotrade.bitstamp import httpclient +from pyalgotrade.bitstamp import common + + +def build_order_from_open_order(openOrder, instrumentTraits): + if openOrder.isBuy(): + action = broker.Order.Action.BUY + elif openOrder.isSell(): + action = broker.Order.Action.SELL + else: + raise Exception("Invalid order type") + + ret = broker.LimitOrder(action, common.btc_symbol, openOrder.getPrice(), openOrder.getAmount(), instrumentTraits) + ret.setSubmitted(openOrder.getId(), openOrder.getDateTime()) + ret.setState(broker.Order.State.ACCEPTED) + return ret + + +class TradeMonitor(threading.Thread): + POLL_FREQUENCY = 2 + + # Events + ON_USER_TRADE = 1 + + def __init__(self, httpClient): + super(TradeMonitor, self).__init__() + self.__lastTradeId = -1 + self.__httpClient = httpClient + self.__queue = queue.Queue() + self.__stop = False + + def _getNewTrades(self): + userTrades = self.__httpClient.getUserTransactions(httpclient.HTTPClient.UserTransactionType.MARKET_TRADE) + + # Get the new trades only. + ret = [t for t in userTrades if t.getId() > self.__lastTradeId] + + # Sort by id, so older trades first. + return sorted(ret, key=lambda t: t.getId()) + + def getQueue(self): + return self.__queue + + def start(self): + trades = self._getNewTrades() + # Store the last trade id since we'll start processing new ones only. + if len(trades): + self.__lastTradeId = trades[-1].getId() + common.logger.info("Last trade found: %d" % (self.__lastTradeId)) + + super(TradeMonitor, self).start() + + def run(self): + while not self.__stop: + try: + trades = self._getNewTrades() + if len(trades): + self.__lastTradeId = trades[-1].getId() + common.logger.info("%d new trade/s found" % (len(trades))) + self.__queue.put((TradeMonitor.ON_USER_TRADE, trades)) + except Exception as e: + common.logger.critical("Error retrieving user transactions", exc_info=e) + + time.sleep(TradeMonitor.POLL_FREQUENCY) + + def stop(self): + self.__stop = True + + +class LiveBroker(broker.Broker): + """A Bitstamp live broker. + + :param clientId: Client id. + :type clientId: string. + :param key: API key. + :type key: string. + :param secret: API secret. + :type secret: string. + + + .. note:: + * Only limit orders are supported. + * Orders are automatically set as **goodTillCanceled=True** and **allOrNone=False**. + * BUY_TO_COVER orders are mapped to BUY orders. + * SELL_SHORT orders are mapped to SELL orders. + * API access permissions should include: + + * Account balance + * Open orders + * Buy limit order + * User transactions + * Cancel order + * Sell limit order + """ + + QUEUE_TIMEOUT = 0.01 + + def __init__(self, clientId, key, secret): + super(LiveBroker, self).__init__() + self.__stop = False + self.__httpClient = self.buildHTTPClient(clientId, key, secret) + self.__tradeMonitor = TradeMonitor(self.__httpClient) + self.__cash = 0 + self.__shares = {} + self.__activeOrders = {} + + def _registerOrder(self, order): + assert(order.getId() not in self.__activeOrders) + assert(order.getId() is not None) + self.__activeOrders[order.getId()] = order + + def _unregisterOrder(self, order): + assert(order.getId() in self.__activeOrders) + assert(order.getId() is not None) + del self.__activeOrders[order.getId()] + + # Factory method for testing purposes. + def buildHTTPClient(self, clientId, key, secret): + return httpclient.HTTPClient(clientId, key, secret) + + def refreshAccountBalance(self): + """Refreshes cash and BTC balance.""" + + self.__stop = True # Stop running in case of errors. + common.logger.info("Retrieving account balance.") + balance = self.__httpClient.getAccountBalance() + + # Cash + self.__cash = round(balance.getUSDAvailable(), 2) + common.logger.info("%s USD" % (self.__cash)) + # BTC + btc = balance.getBTCAvailable() + if btc: + self.__shares = {common.btc_symbol: btc} + else: + self.__shares = {} + common.logger.info("%s BTC" % (btc)) + + self.__stop = False # No errors. Keep running. + + def refreshOpenOrders(self): + self.__stop = True # Stop running in case of errors. + common.logger.info("Retrieving open orders.") + openOrders = self.__httpClient.getOpenOrders() + for openOrder in openOrders: + self._registerOrder(build_order_from_open_order(openOrder, self.getInstrumentTraits(common.btc_symbol))) + + common.logger.info("%d open order/s found" % (len(openOrders))) + self.__stop = False # No errors. Keep running. + + def _startTradeMonitor(self): + self.__stop = True # Stop running in case of errors. + common.logger.info("Initializing trade monitor.") + self.__tradeMonitor.start() + self.__stop = False # No errors. Keep running. + + def _onUserTrades(self, trades): + for trade in trades: + order = self.__activeOrders.get(trade.getOrderId()) + if order is not None: + fee = trade.getFee() + fillPrice = trade.getBTCUSD() + btcAmount = trade.getBTC() + dateTime = trade.getDateTime() + + # Update cash and shares. + self.refreshAccountBalance() + # Update the order. + orderExecutionInfo = broker.OrderExecutionInfo(fillPrice, abs(btcAmount), fee, dateTime) + order.addExecutionInfo(orderExecutionInfo) + if not order.isActive(): + self._unregisterOrder(order) + # Notify that the order was updated. + if order.isFilled(): + eventType = broker.OrderEvent.Type.FILLED + else: + eventType = broker.OrderEvent.Type.PARTIALLY_FILLED + self.notifyOrderEvent(broker.OrderEvent(order, eventType, orderExecutionInfo)) + else: + common.logger.info("Trade %d refered to order %d that is not active" % (trade.getId(), trade.getOrderId())) + + # BEGIN observer.Subject interface + def start(self): + super(LiveBroker, self).start() + self.refreshAccountBalance() + self.refreshOpenOrders() + self._startTradeMonitor() + + def stop(self): + self.__stop = True + common.logger.info("Shutting down trade monitor.") + self.__tradeMonitor.stop() + + def join(self): + if self.__tradeMonitor.isAlive(): + self.__tradeMonitor.join() + + def eof(self): + return self.__stop + + def dispatch(self): + # Switch orders from SUBMITTED to ACCEPTED. + ordersToProcess = list(self.__activeOrders.values()) + for order in ordersToProcess: + if order.isSubmitted(): + order.switchState(broker.Order.State.ACCEPTED) + self.notifyOrderEvent(broker.OrderEvent(order, broker.OrderEvent.Type.ACCEPTED, None)) + + # Dispatch events from the trade monitor. + try: + eventType, eventData = self.__tradeMonitor.getQueue().get(True, LiveBroker.QUEUE_TIMEOUT) + + if eventType == TradeMonitor.ON_USER_TRADE: + self._onUserTrades(eventData) + else: + common.logger.error("Invalid event received to dispatch: %s - %s" % (eventType, eventData)) + except queue.Empty: + pass + + def peekDateTime(self): + # Return None since this is a realtime subject. + return None + + # END observer.Subject interface + + # BEGIN broker.Broker interface + + def getCash(self, includeShort=True): + return self.__cash + + def getInstrumentTraits(self, instrument): + return common.BTCTraits() + + def getShares(self, instrument): + return self.__shares.get(instrument, 0) + + def getPositions(self): + return self.__shares + + def getActiveOrders(self, instrument=None): + return list(self.__activeOrders.values()) + + def submitOrder(self, order): + if order.isInitial(): + # Override user settings based on Bitstamp limitations. + order.setAllOrNone(False) + order.setGoodTillCanceled(True) + + if order.isBuy(): + bitstampOrder = self.__httpClient.buyLimit(order.getLimitPrice(), order.getQuantity()) + else: + bitstampOrder = self.__httpClient.sellLimit(order.getLimitPrice(), order.getQuantity()) + + order.setSubmitted(bitstampOrder.getId(), bitstampOrder.getDateTime()) + self._registerOrder(order) + # Switch from INITIAL -> SUBMITTED + # IMPORTANT: Do not emit an event for this switch because when using the position interface + # the order is not yet mapped to the position and Position.onOrderUpdated will get called. + order.switchState(broker.Order.State.SUBMITTED) + else: + raise Exception("The order was already processed") + + def createMarketOrder(self, action, instrument, quantity, onClose=False): + raise Exception("Market orders are not supported") + + def createLimitOrder(self, action, instrument, limitPrice, quantity): + if instrument != common.btc_symbol: + raise Exception("Only BTC instrument is supported") + + if action == broker.Order.Action.BUY_TO_COVER: + action = broker.Order.Action.BUY + elif action == broker.Order.Action.SELL_SHORT: + action = broker.Order.Action.SELL + + if action not in [broker.Order.Action.BUY, broker.Order.Action.SELL]: + raise Exception("Only BUY/SELL orders are supported") + + instrumentTraits = self.getInstrumentTraits(instrument) + limitPrice = round(limitPrice, 2) + quantity = instrumentTraits.roundQuantity(quantity) + return broker.LimitOrder(action, instrument, limitPrice, quantity, instrumentTraits) + + def createStopOrder(self, action, instrument, stopPrice, quantity): + raise Exception("Stop orders are not supported") + + def createStopLimitOrder(self, action, instrument, stopPrice, limitPrice, quantity): + raise Exception("Stop limit orders are not supported") + + def cancelOrder(self, order): + activeOrder = self.__activeOrders.get(order.getId()) + if activeOrder is None: + raise Exception("The order is not active anymore") + if activeOrder.isFilled(): + raise Exception("Can't cancel order that has already been filled") + + self.__httpClient.cancelOrder(order.getId()) + self._unregisterOrder(order) + order.switchState(broker.Order.State.CANCELED) + + # Update cash and shares. + self.refreshAccountBalance() + + # Notify that the order was canceled. + self.notifyOrderEvent(broker.OrderEvent(order, broker.OrderEvent.Type.CANCELED, "User requested cancellation")) + + # END broker.Broker interface diff --git a/pyalgotrade/alpaca/livefeed.py b/pyalgotrade/alpaca/livefeed.py new file mode 100644 index 000000000..f874f06b6 --- /dev/null +++ b/pyalgotrade/alpaca/livefeed.py @@ -0,0 +1,267 @@ +# PyAlgoTrade +# +# Copyright 2011-2018 Gabriel Martin Becedillas Ruiz +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +.. moduleauthor:: Gabriel Martin Becedillas Ruiz +""" + +import datetime +import time + +from six.moves import queue + +from pyalgotrade import bar +from pyalgotrade import barfeed +from pyalgotrade import observer +from pyalgotrade.bitstamp import common +from pyalgotrade.bitstamp import wsclient + + +class TradeBar(bar.Bar): + # Optimization to reduce memory footprint. + __slots__ = ('__dateTime', '__tradeId', '__price', '__amount') + + def __init__(self, dateTime, trade): + self.__dateTime = dateTime + self.__tradeId = trade.getId() + self.__price = trade.getPrice() + self.__amount = trade.getAmount() + self.__buy = trade.isBuy() + + def __setstate__(self, state): + (self.__dateTime, self.__tradeId, self.__price, self.__amount) = state + + def __getstate__(self): + return (self.__dateTime, self.__tradeId, self.__price, self.__amount) + + def setUseAdjustedValue(self, useAdjusted): + if useAdjusted: + raise Exception("Adjusted close is not available") + + def getTradeId(self): + return self.__tradeId + + def getFrequency(self): + return bar.Frequency.TRADE + + def getDateTime(self): + return self.__dateTime + + def getOpen(self, adjusted=False): + return self.__price + + def getHigh(self, adjusted=False): + return self.__price + + def getLow(self, adjusted=False): + return self.__price + + def getClose(self, adjusted=False): + return self.__price + + def getVolume(self): + return self.__amount + + def getAdjClose(self): + return None + + def getTypicalPrice(self): + return self.__price + + def getPrice(self): + return self.__price + + def getUseAdjValue(self): + return False + + def isBuy(self): + return self.__buy + + def isSell(self): + return not self.__buy + + +class LiveTradeFeed(barfeed.BaseBarFeed): + + """A real-time BarFeed that builds bars from live trades. + + :param maxLen: The maximum number of values that the :class:`pyalgotrade.dataseries.bards.BarDataSeries` will hold. + Once a bounded length is full, when new items are added, a corresponding number of items are discarded + from the opposite end. If None then dataseries.DEFAULT_MAX_LEN is used. + :type maxLen: int. + + .. note:: + Note that a Bar will be created for every trade, so open, high, low and close values will all be the same. + """ + + QUEUE_TIMEOUT = 0.01 + + def __init__(self, maxLen=None): + super(LiveTradeFeed, self).__init__(bar.Frequency.TRADE, maxLen) + self.__barDicts = [] + self.registerInstrument(common.btc_symbol) + self.__prevTradeDateTime = None + self.__thread = None + self.__wsClientConnected = False + self.__enableReconnection = True + self.__stopped = False + self.__orderBookUpdateEvent = observer.Event() + + # Factory method for testing purposes. + def buildWebSocketClientThread(self): + return wsclient.WebSocketClientThread() + + def getCurrentDateTime(self): + return wsclient.get_current_datetime() + + def enableReconection(self, enableReconnection): + self.__enableReconnection = enableReconnection + + def __initializeClient(self): + common.logger.info("Initializing websocket client.") + assert self.__wsClientConnected is False, "Websocket client already connected" + + try: + # Start the thread that runs the client. + self.__thread = self.buildWebSocketClientThread() + self.__thread.start() + except Exception as e: + common.logger.exception("Error connecting : %s" % str(e)) + + # Wait for initialization to complete. + while not self.__wsClientConnected and self.__thread.is_alive(): + self.__dispatchImpl([wsclient.WebSocketClient.Event.CONNECTED]) + + if self.__wsClientConnected: + common.logger.info("Initialization ok.") + else: + common.logger.error("Initialization failed.") + return self.__wsClientConnected + + def __onConnected(self): + self.__wsClientConnected = True + + def __onDisconnected(self): + self.__wsClientConnected = False + + if self.__enableReconnection: + initialized = False + while not self.__stopped and not initialized: + common.logger.info("Reconnecting") + initialized = self.__initializeClient() + if not initialized: + time.sleep(5) + else: + self.__stopped = True + + def __dispatchImpl(self, eventFilter): + ret = False + try: + eventType, eventData = self.__thread.getQueue().get(True, LiveTradeFeed.QUEUE_TIMEOUT) + if eventFilter is not None and eventType not in eventFilter: + return False + + ret = True + if eventType == wsclient.WebSocketClient.Event.TRADE: + self.__onTrade(eventData) + elif eventType == wsclient.WebSocketClient.Event.ORDER_BOOK_UPDATE: + self.__orderBookUpdateEvent.emit(eventData) + elif eventType == wsclient.WebSocketClient.Event.CONNECTED: + self.__onConnected() + elif eventType == wsclient.WebSocketClient.Event.DISCONNECTED: + self.__onDisconnected() + else: + ret = False + common.logger.error("Invalid event received to dispatch: %s - %s" % (eventType, eventData)) + except queue.Empty: + pass + return ret + + # Bar datetimes should not duplicate. In case trade object datetimes conflict, we just move one slightly forward. + def __getTradeDateTime(self, trade): + ret = trade.getDateTime() + if ret == self.__prevTradeDateTime: + ret += datetime.timedelta(microseconds=1) + self.__prevTradeDateTime = ret + return ret + + def __onTrade(self, trade): + # Build a bar for each trade. + barDict = { + common.btc_symbol: TradeBar(self.__getTradeDateTime(trade), trade) + } + self.__barDicts.append(barDict) + + def barsHaveAdjClose(self): + return False + + def getNextBars(self): + ret = None + if len(self.__barDicts): + ret = bar.Bars(self.__barDicts.pop(0)) + return ret + + def peekDateTime(self): + # Return None since this is a realtime subject. + return None + + # This may raise. + def start(self): + super(LiveTradeFeed, self).start() + if self.__thread is not None: + raise Exception("Already running") + elif not self.__initializeClient(): + self.__stopped = True + raise Exception("Initialization failed") + + def dispatch(self): + # Note that we may return True even if we didn't dispatch any Bar + # event. + ret = False + if self.__dispatchImpl(None): + ret = True + if super(LiveTradeFeed, self).dispatch(): + ret = True + return ret + + # This should not raise. + def stop(self): + try: + self.__stopped = True + if self.__thread is not None and self.__thread.is_alive(): + common.logger.info("Shutting down websocket client.") + self.__thread.stop() + except Exception as e: + common.logger.error("Error shutting down client: %s" % (str(e))) + + # This should not raise. + def join(self): + if self.__thread is not None: + self.__thread.join() + + def eof(self): + return self.__stopped + + def getOrderBookUpdateEvent(self): + """ + Returns the event that will be emitted when the orderbook gets updated. + + Eventh handlers should receive one parameter: + 1. A :class:`pyalgotrade.bitstamp.wsclient.OrderBookUpdate` instance. + + :rtype: :class:`pyalgotrade.observer.Event`. + """ + return self.__orderBookUpdateEvent diff --git a/pyalgotrade/alpaca/wsclient.py b/pyalgotrade/alpaca/wsclient.py new file mode 100644 index 000000000..0e7a4c67a --- /dev/null +++ b/pyalgotrade/alpaca/wsclient.py @@ -0,0 +1,198 @@ +# PyAlgoTrade +# +# Copyright 2011-2018 Gabriel Martin Becedillas Ruiz +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +.. moduleauthor:: Gabriel Martin Becedillas Ruiz +""" + +import datetime + +from six.moves import queue + +from pyalgotrade.websocket import pusher +from pyalgotrade.websocket import client +from pyalgotrade.bitstamp import common + + +def get_current_datetime(): + return datetime.datetime.now() + +# Bitstamp protocol reference: https://www.bitstamp.net/websocket/ + + +class Trade(pusher.Event): + """A trade event.""" + + def __init__(self, dateTime, eventDict): + super(Trade, self).__init__(eventDict, True) + self.__dateTime = dateTime + + def getDateTime(self): + """Returns the :class:`datetime.datetime` when this event was received.""" + return self.__dateTime + + def getId(self): + """Returns the trade id.""" + return self.getData()["id"] + + def getPrice(self): + """Returns the trade price.""" + return self.getData()["price"] + + def getAmount(self): + """Returns the trade amount.""" + return self.getData()["amount"] + + def isBuy(self): + """Returns True if the trade was a buy.""" + return self.getData()["type"] == 0 + + def isSell(self): + """Returns True if the trade was a sell.""" + return self.getData()["type"] == 1 + + +class OrderBookUpdate(pusher.Event): + """An order book update event.""" + + def __init__(self, dateTime, eventDict): + super(OrderBookUpdate, self).__init__(eventDict, True) + self.__dateTime = dateTime + + def getDateTime(self): + """Returns the :class:`datetime.datetime` when this event was received.""" + return self.__dateTime + + def getBidPrices(self): + """Returns a list with the top 20 bid prices.""" + return [float(bid[0]) for bid in self.getData()["bids"]] + + def getBidVolumes(self): + """Returns a list with the top 20 bid volumes.""" + return [float(bid[1]) for bid in self.getData()["bids"]] + + def getAskPrices(self): + """Returns a list with the top 20 ask prices.""" + return [float(ask[0]) for ask in self.getData()["asks"]] + + def getAskVolumes(self): + """Returns a list with the top 20 ask volumes.""" + return [float(ask[1]) for ask in self.getData()["asks"]] + + +class WebSocketClient(pusher.WebSocketClient): + """ + This websocket client class is designed to be running in a separate thread and for that reason + events are pushed into a queue. + """ + + PUSHER_APP_KEY = "de504dc5763aeef9ff52" + + class Event: + TRADE = 1 + ORDER_BOOK_UPDATE = 2 + CONNECTED = 3 + DISCONNECTED = 4 + + def __init__(self, queue): + super(WebSocketClient, self).__init__(WebSocketClient.PUSHER_APP_KEY, 5) + self.__queue = queue + + def onMessage(self, msg): + # If we can't handle the message, forward it to Pusher WebSocketClient. + event = msg.get("event") + if event == "trade": + self.onTrade(Trade(get_current_datetime(), msg)) + elif event == "data" and msg.get("channel") == "order_book": + self.onOrderBookUpdate(OrderBookUpdate(get_current_datetime(), msg)) + else: + super(WebSocketClient, self).onMessage(msg) + + ###################################################################### + # WebSocketClientBase events. + + def onClosed(self, code, reason): + common.logger.info("Closed. Code: %s. Reason: %s." % (code, reason)) + self.__queue.put((WebSocketClient.Event.DISCONNECTED, None)) + + def onDisconnectionDetected(self): + common.logger.warning("Disconnection detected.") + try: + self.stopClient() + except Exception as e: + common.logger.error("Error stopping websocket client: %s." % (str(e))) + self.__queue.put((WebSocketClient.Event.DISCONNECTED, None)) + + ###################################################################### + # Pusher specific events. + + def onConnectionEstablished(self, event): + common.logger.info("Connection established.") + self.__queue.put((WebSocketClient.Event.CONNECTED, None)) + + channels = ["live_trades", "order_book"] + common.logger.info("Subscribing to channels %s." % channels) + for channel in channels: + self.subscribeChannel(channel) + + def onError(self, event): + common.logger.error("Error: %s" % (event)) + + def onUnknownEvent(self, event): + common.logger.warning("Unknown event: %s" % (event)) + + ###################################################################### + # Bitstamp specific + + def onTrade(self, trade): + self.__queue.put((WebSocketClient.Event.TRADE, trade)) + + def onOrderBookUpdate(self, orderBookUpdate): + self.__queue.put((WebSocketClient.Event.ORDER_BOOK_UPDATE, orderBookUpdate)) + + +class WebSocketClientThread(client.WebSocketClientThreadBase): + """ + This thread class is responsible for running a WebSocketClient. + """ + + def __init__(self): + super(WebSocketClientThread, self).__init__() + self.__queue = queue.Queue() + self.__wsClient = None + + def getQueue(self): + return self.__queue + + def run(self): + super(WebSocketClientThread, self).run() + + # We create the WebSocketClient right in the thread, instead of doing so in the constructor, + # because it has thread affinity. + try: + self.__wsClient = WebSocketClient(self.__queue) + self.__wsClient.connect() + self.__wsClient.startClient() + except Exception: + common.logger.exception("Failed to connect: %s") + + def stop(self): + try: + if self.__wsClient is not None: + common.logger.info("Stopping websocket client.") + self.__wsClient.stopClient() + except Exception as e: + common.logger.error("Error stopping websocket client: %s." % (str(e))) From 283c0aff16efb5a87b494cc06390a6049e93aec8 Mon Sep 17 00:00:00 2001 From: jibi Date: Fri, 17 Dec 2021 22:55:04 -0800 Subject: [PATCH 3/7] work on alpaca restfeed. --- pyalgotrade/alpaca/barfeed.py | 25 ---- pyalgotrade/alpaca/common.py | 10 +- pyalgotrade/alpaca/httpclient.py | 4 +- pyalgotrade/alpaca/restfeed.py | 206 +++++++++++++++++++++++++++++++ 4 files changed, 213 insertions(+), 32 deletions(-) delete mode 100644 pyalgotrade/alpaca/barfeed.py create mode 100644 pyalgotrade/alpaca/restfeed.py diff --git a/pyalgotrade/alpaca/barfeed.py b/pyalgotrade/alpaca/barfeed.py deleted file mode 100644 index 4abd536b3..000000000 --- a/pyalgotrade/alpaca/barfeed.py +++ /dev/null @@ -1,25 +0,0 @@ -# PyAlgoTrade -# -# Copyright 2011-2018 Gabriel Martin Becedillas Ruiz -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -.. moduleauthor:: Gabriel Martin Becedillas Ruiz -""" - - -from pyalgotrade.bitstamp import livefeed - - -LiveTradeFeed = livefeed.LiveTradeFeed diff --git a/pyalgotrade/alpaca/common.py b/pyalgotrade/alpaca/common.py index b42b87bf5..1d44d59bc 100644 --- a/pyalgotrade/alpaca/common.py +++ b/pyalgotrade/alpaca/common.py @@ -22,10 +22,10 @@ from pyalgotrade import broker -logger = pyalgotrade.logger.getLogger("bitstamp") -btc_symbol = "BTC" +logger = pyalgotrade.logger.getLogger("alpaca") +# btc_symbol = "BTC" -class BTCTraits(broker.InstrumentTraits): - def roundQuantity(self, quantity): - return round(quantity, 8) +# class BTCTraits(broker.InstrumentTraits): +# def roundQuantity(self, quantity): +# return round(quantity, 8) diff --git a/pyalgotrade/alpaca/httpclient.py b/pyalgotrade/alpaca/httpclient.py index e51f4e8da..be49fe73e 100644 --- a/pyalgotrade/alpaca/httpclient.py +++ b/pyalgotrade/alpaca/httpclient.py @@ -15,7 +15,7 @@ # limitations under the License. """ -.. moduleauthor:: Gabriel Martin Becedillas Ruiz +.. moduleauthor:: Robert Lee """ import time @@ -26,7 +26,7 @@ import threading from pyalgotrade.utils import dt -from pyalgotrade.bitstamp import common +from pyalgotrade.alpaca import common import logging logging.getLogger("requests").setLevel(logging.ERROR) diff --git a/pyalgotrade/alpaca/restfeed.py b/pyalgotrade/alpaca/restfeed.py new file mode 100644 index 000000000..338900a37 --- /dev/null +++ b/pyalgotrade/alpaca/restfeed.py @@ -0,0 +1,206 @@ +# PyAlgoTrade +# +# Copyright 2011-2018 Gabriel Martin Becedillas Ruiz +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +.. moduleauthor:: Robert Lee +https://github.com/alpacahq/alpaca-trade-api-python/blob/master/examples/historic_async.py +""" + + +from enum import Enum +import datetime +import os +import sys +import asyncio +import argparse + +import pandas as pd + +import alpaca_trade_api as tradeapi +from alpaca_trade_api.rest import TimeFrame, URL +from alpaca_trade_api.rest_async import gather_with_concurrency, AsyncRest + +from pyalgotrade.alpaca import common +# from pyalgotrade.alpaca import livefeed + +# LiveTradeFeed = livefeed.LiveTradeFeed + + +NY = 'America/New_York' + +class DataType(str, Enum): + Bars = "Bars" + Trades = "Trades" + Quotes = "Quotes" + + +def get_data_method(data_type: DataType): + if data_type == DataType.Bars: + return rest.get_bars_async + elif data_type == DataType.Trades: + return rest.get_trades_async + elif data_type == DataType.Quotes: + return rest.get_quotes_async + else: + raise Exception(f"Unsupoported data type: {data_type}") + + +async def get_historic_data_base(symbols, data_type: DataType, start, end, + timeframe: TimeFrame = None): + """ + base function to use with all + :param symbols: + :param start: + :param end: + :param timeframe: + :return: + """ + # Check Python version + major = sys.version_info.major + minor = sys.version_info.minor + if major < 3 or minor < 6: + raise Exception('asyncio is not support in your python version') + msg = f"Getting {data_type} data for {len(symbols)} symbols" + msg += f", timeframe: {timeframe}" if timeframe else "" + msg += f" between dates: start={start}, end={end}" + common.logger.info(msg) + + # loop through 1000 symbols at a time + step_size = 1000 + results = [] + for i in range(0, len(symbols), step_size): + tasks = [] + for symbol in symbols[i:i+step_size]: + args = [symbol, start, end, timeframe.value] if timeframe else \ + [symbol, start, end] + tasks.append(get_data_method(data_type)(*args)) + + if minor >= 8: + results.extend(await asyncio.gather(*tasks, return_exceptions=True)) + else: + results.extend(await gather_with_concurrency(500, *tasks)) + + bad_requests = 0 + for response in results: + if isinstance(response, Exception): + common.logger.error(f"Got an error: {response}") + elif not len(response[1]): + bad_requests += 1 + + common.logger.info(f"Total of {len(results)} {data_type}, and {bad_requests} " + f"empty responses.") + + return results + + +# async def get_historic_bars(symbols, start, end, timeframe: TimeFrame): +# await get_historic_data_base(symbols, DataType.Bars, start, end, timeframe) + + +# async def get_historic_trades(symbols, start, end, timeframe: TimeFrame): +# await get_historic_data_base(symbols, DataType.Trades, start, end) + + +# async def get_historic_quotes(symbols, start, end, timeframe: TimeFrame): +# await get_historic_data_base(symbols, DataType.Quotes, start, end) + + +# async def main(symbols, start_time, end_time, timeframe): +# start = pd.Timestamp(start_time, tz=NY).date().isoformat() +# end = pd.Timestamp(end_time, tz=NY).date().isoformat() + + + +# # await get_historic_bars(symbols, start, end, timeframe) +# # await get_historic_trades(symbols, start, end, timeframe) +# # await get_historic_quotes(symbols, start, end, timeframe) + + +if __name__ == '__main__': + + # Get parameters + parser = argparse.ArgumentParser(description="Alpaca Rest Datafeed") + + # data request + parser.add_argument("--symbols", required = True, nargs = '+', + help = "One or more symbols for which to download data.") + parser.add_argument("--start-date", required=True, + type=str, help="Start date of data.") + parser.add_argument("--end-date", required=True, + type=str, help="End date of data.") + parser.add_argument("--datatype", required = False, default="bars", + choices = ['bars', 'trades', 'quotes'], + help="The type of data to request. One of bars, trades, or quotes.") + parser.add_argument("--timeframe", required = False, default="1Day", + help="The frequency of the bars, in format [n]Min, [n]Hour, or [n]Day.") + # credentials + parser.add_argument("--api-key-id", required=False, + help="Alpaca Key ID if it is not saved as an environment variable.") + parser.add_argument("--api-secret-key", required=False, + help="Alpaca secret key if it is not saved as an environment variable.") + # storage + parser.add_argument("--storage", required=True, + help="The path were the files will be downloaded to") + # parser.add_argument("--force-download", action='store_false', + # help="Force downloading even if the files exist") + + # Set up variables + args = parser.parse_args() + + # credentials + api_key_id = args.api_key_id or os.environ.get('ALPACA_API_KEY_ID') + api_secret_key = args.api_secret_key or os.environ.get('ALPACA_API_SECRET_KEY') + # data request + symbols = args.symbols + start_date = pd.Timestamp(args.start_date, tz=NY).date().isoformat() + end_date = pd.Timestamp(args.end_date, tz=NY).date().isoformat() + if args.datatype == 'bars': + datatype = DataType.Bars + elif args.datatype == 'trades': + datatype = DataType.Trades + elif args.datatype == 'quotes': + datatype = DataType.Quotes + timeframe = args.timeframe + # storage + if not os.path.exists(args.storage): + common.logger.info("Creating %s directory" % (args.storage)) + os.mkdir(args.storage) + storage = args.storage + + + # Make connection + base_url = "https://paper-api.alpaca.markets" + rest = AsyncRest(key_id=api_key_id, + secret_key=api_secret_key) + feed = "sip" # change to "iex" if only free account + + api = tradeapi.REST(key_id=api_key_id, + secret_key=api_secret_key, + base_url=URL(base_url)) + + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + get_historic_data_base(symbols, datatype, start_date, end_date, timeframe) + ) + # f = open(storage, "w") + # f.write(bars) + # f.close() + + + +# TODO +# split out function for non-command line use +# test functions \ No newline at end of file From 37cedda99df1b2b7be916b07545bbe7c7da19a71 Mon Sep 17 00:00:00 2001 From: jibi Date: Thu, 23 Dec 2021 12:00:59 -0800 Subject: [PATCH 4/7] Added make_async_connection in alpaca.common and get_historic_data in alpaca.restfeed --- pyalgotrade/alpaca/common.py | 39 ++++++++ pyalgotrade/alpaca/restfeed.py | 169 ++++++++++++++++----------------- 2 files changed, 120 insertions(+), 88 deletions(-) diff --git a/pyalgotrade/alpaca/common.py b/pyalgotrade/alpaca/common.py index 1d44d59bc..c7d48840b 100644 --- a/pyalgotrade/alpaca/common.py +++ b/pyalgotrade/alpaca/common.py @@ -17,12 +17,51 @@ """ .. moduleauthor:: Gabriel Martin Becedillas Ruiz """ +from enum import Enum +import os + +from alpaca_trade_api.rest_async import AsyncRest import pyalgotrade.logger from pyalgotrade import broker logger = pyalgotrade.logger.getLogger("alpaca") + + +def make_async_rest_connection(api_key_id = None, api_secret_key = None): + + # credentials + api_key_id = api_key_id or os.environ.get('ALPACA_API_KEY_ID_PAPER') + api_secret_key = api_secret_key or os.environ.get('ALPACA_API_SECRET_KEY_PAPER') + + if api_key_id is None: + logger.error('Unable to retrieve API Key ID.') + if api_key_id is None: + logger.error('Unable to retrieve API Secret Key.') + + rest = AsyncRest(key_id=api_key_id, + secret_key=api_secret_key) + + return rest + + + + + + + + + + + + + + + + + + # btc_symbol = "BTC" diff --git a/pyalgotrade/alpaca/restfeed.py b/pyalgotrade/alpaca/restfeed.py index 338900a37..0adbce493 100644 --- a/pyalgotrade/alpaca/restfeed.py +++ b/pyalgotrade/alpaca/restfeed.py @@ -17,6 +17,15 @@ """ .. moduleauthor:: Robert Lee https://github.com/alpacahq/alpaca-trade-api-python/blob/master/examples/historic_async.py + +Example usage: + from pyalgotrade.alpaca.common import make_async_rest_connection + from pyalgotrade.alpaca.restfeed import get_historic_data + + async_rest = make_async_rest_connection(api_key_id, api_secret_key) + results = get_historic_data(async_rest, ['AAPL', 'IBM'], '2021-01-01', '2021-01-10, 'QUOTES') + + """ @@ -41,32 +50,29 @@ NY = 'America/New_York' -class DataType(str, Enum): - Bars = "Bars" - Trades = "Trades" - Quotes = "Quotes" - - -def get_data_method(data_type: DataType): - if data_type == DataType.Bars: - return rest.get_bars_async - elif data_type == DataType.Trades: - return rest.get_trades_async - elif data_type == DataType.Quotes: - return rest.get_quotes_async - else: - raise Exception(f"Unsupoported data type: {data_type}") - - -async def get_historic_data_base(symbols, data_type: DataType, start, end, - timeframe: TimeFrame = None): +async def get_historic_data(async_rest, symbols, start_date, end_date, + data_type = 'BARS', timeframe = '1Day'): """ - base function to use with all - :param symbols: - :param start: - :param end: - :param timeframe: - :return: + Retrieve historic data for multiple symbols using Alpaca's get_[datatype]_async + from the AsyncRest object. + + Args: + async_rest (Alpaca AsyncRest object): See alpaca_trade_api.rest_async.AsyncRest. + symbols (list): A list of symbols for which to get data. + start_date (str): Start date of time period of data request. + end_date (str): End date of time period of data request. + data_type (str, optional): One of 'BARS', 'TRADES', or 'QUOTES'. Defaults to 'BARS'. + timeframe (str): Frequency of data requested. Format as [amount][unit], + where [amount]is an integer, and [unit] is one of Min, Hour, or Day. Defaults to 1Day. + Ignored if data_type is not 'BARS'. + + Returns: + [(symbol, df),]: List of tuples of (symbol, pandas DataFrame) + + Usage: + async_rest = make_async_rest_connetion() + symbols, dfs = + """ # Check Python version major = sys.version_info.major @@ -75,24 +81,44 @@ async def get_historic_data_base(symbols, data_type: DataType, start, end, raise Exception('asyncio is not support in your python version') msg = f"Getting {data_type} data for {len(symbols)} symbols" msg += f", timeframe: {timeframe}" if timeframe else "" - msg += f" between dates: start={start}, end={end}" + msg += f" between dates: start={start_date}, end={end_date}" common.logger.info(msg) - # loop through 1000 symbols at a time + # define what data we're trying to get + if data_type.upper() == 'BARS': + get_data_method = async_rest.get_bars_async + elif data_type.upper() == 'TRADES': + get_data_method = async_rest.get_trades_async + elif data_type.upper() == 'QUOTES': + get_data_method = async_rest.get_quotes_async + else: + raise Exception(f"Unsupoported data type: {data_type}") + + # Time period of data request + start_date = pd.Timestamp(start_date, tz=NY).date().isoformat() + end_date = pd.Timestamp(end_date, tz=NY).date().isoformat() + + # ignore timeframe argument if data_type is not 'BARS' + if data_type.upper() != 'BARS': + timeframe = None + + # Create one task for each symbol + # execute up to 1000 tasks each loop step_size = 1000 results = [] for i in range(0, len(symbols), step_size): tasks = [] for symbol in symbols[i:i+step_size]: - args = [symbol, start, end, timeframe.value] if timeframe else \ - [symbol, start, end] - tasks.append(get_data_method(data_type)(*args)) + args = [symbol, start_date, end_date, timeframe] if timeframe else \ + [symbol, start_date, end_date] + tasks.append(get_data_method(*args)) if minor >= 8: results.extend(await asyncio.gather(*tasks, return_exceptions=True)) else: results.extend(await gather_with_concurrency(500, *tasks)) - + + # notify the user of any bad reuests bad_requests = 0 for response in results: if isinstance(response, Exception): @@ -105,30 +131,6 @@ async def get_historic_data_base(symbols, data_type: DataType, start, end, return results - -# async def get_historic_bars(symbols, start, end, timeframe: TimeFrame): -# await get_historic_data_base(symbols, DataType.Bars, start, end, timeframe) - - -# async def get_historic_trades(symbols, start, end, timeframe: TimeFrame): -# await get_historic_data_base(symbols, DataType.Trades, start, end) - - -# async def get_historic_quotes(symbols, start, end, timeframe: TimeFrame): -# await get_historic_data_base(symbols, DataType.Quotes, start, end) - - -# async def main(symbols, start_time, end_time, timeframe): -# start = pd.Timestamp(start_time, tz=NY).date().isoformat() -# end = pd.Timestamp(end_time, tz=NY).date().isoformat() - - - -# # await get_historic_bars(symbols, start, end, timeframe) -# # await get_historic_trades(symbols, start, end, timeframe) -# # await get_historic_quotes(symbols, start, end, timeframe) - - if __name__ == '__main__': # Get parameters @@ -160,47 +162,38 @@ async def get_historic_data_base(symbols, data_type: DataType, start, end, # Set up variables args = parser.parse_args() - # credentials - api_key_id = args.api_key_id or os.environ.get('ALPACA_API_KEY_ID') - api_secret_key = args.api_secret_key or os.environ.get('ALPACA_API_SECRET_KEY') - # data request - symbols = args.symbols - start_date = pd.Timestamp(args.start_date, tz=NY).date().isoformat() - end_date = pd.Timestamp(args.end_date, tz=NY).date().isoformat() - if args.datatype == 'bars': - datatype = DataType.Bars - elif args.datatype == 'trades': - datatype = DataType.Trades - elif args.datatype == 'quotes': - datatype = DataType.Quotes - timeframe = args.timeframe + # make rest connection to API + async_rest = common.make_async_rest_connection(args.api_key_id, args.api_secret_key) + # storage if not os.path.exists(args.storage): common.logger.info("Creating %s directory" % (args.storage)) os.mkdir(args.storage) storage = args.storage - - # Make connection - base_url = "https://paper-api.alpaca.markets" - rest = AsyncRest(key_id=api_key_id, - secret_key=api_secret_key) - feed = "sip" # change to "iex" if only free account - - api = tradeapi.REST(key_id=api_key_id, - secret_key=api_secret_key, - base_url=URL(base_url)) - + # rest of data request + symbols = args.symbols + start_date = args.start_date + end_date = args.end_date + datatype = args.datatype.upper() + timeframe = args.timeframe + + # Request the data loop = asyncio.get_event_loop() results = loop.run_until_complete( - get_historic_data_base(symbols, datatype, start_date, end_date, timeframe) + get_historic_data(async_rest, symbols, start_date, end_date, datatype, timeframe) ) - # f = open(storage, "w") - # f.write(bars) - # f.close() - + # Stack the results into 1 dataframe + # Current it is in [(symbol0, df0), (symbol1, df1)] format + result = None + for symbol_i, df_i in results: + df_i['symbol'] = symbol_i + df_i = df_i.reset_index().set_index(['symbol', 'timestamp']) + if result is None: + result = df_i + else: + result = pd.concat([result, df_i], axis = 0, ignore_index = True) -# TODO -# split out function for non-command line use -# test functions \ No newline at end of file + # save to csv + result.to_csv(storage) \ No newline at end of file From 9715f006e158d2f7cd54565d511c2fabe8d8a58b Mon Sep 17 00:00:00 2001 From: jibi Date: Thu, 23 Dec 2021 13:34:05 -0800 Subject: [PATCH 5/7] alpaca rest historical feed completed and tested --- pyalgotrade/alpaca/common.py | 4 ++-- pyalgotrade/alpaca/restfeed.py | 30 +++++++++++++++++++----------- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/pyalgotrade/alpaca/common.py b/pyalgotrade/alpaca/common.py index c7d48840b..665b1428b 100644 --- a/pyalgotrade/alpaca/common.py +++ b/pyalgotrade/alpaca/common.py @@ -32,8 +32,8 @@ def make_async_rest_connection(api_key_id = None, api_secret_key = None): # credentials - api_key_id = api_key_id or os.environ.get('ALPACA_API_KEY_ID_PAPER') - api_secret_key = api_secret_key or os.environ.get('ALPACA_API_SECRET_KEY_PAPER') + api_key_id = api_key_id or os.environ.get('ALPACA_API_KEY_ID') + api_secret_key = api_secret_key or os.environ.get('ALPACA_API_SECRET_KEY') if api_key_id is None: logger.error('Unable to retrieve API Key ID.') diff --git a/pyalgotrade/alpaca/restfeed.py b/pyalgotrade/alpaca/restfeed.py index 0adbce493..2812288d8 100644 --- a/pyalgotrade/alpaca/restfeed.py +++ b/pyalgotrade/alpaca/restfeed.py @@ -19,11 +19,21 @@ https://github.com/alpacahq/alpaca-trade-api-python/blob/master/examples/historic_async.py Example usage: - from pyalgotrade.alpaca.common import make_async_rest_connection - from pyalgotrade.alpaca.restfeed import get_historic_data - async_rest = make_async_rest_connection(api_key_id, api_secret_key) - results = get_historic_data(async_rest, ['AAPL', 'IBM'], '2021-01-01', '2021-01-10, 'QUOTES') + In a script: + from pyalgotrade.alpaca.common import make_async_rest_connection + from pyalgotrade.alpaca.restfeed import get_historic_data + + async_rest = make_async_rest_connection() + results = get_historic_data(async_rest, ['AAPL', 'IBM'], '2021-01-01', '2021-01-10, 'QUOTES') + # results = [('AAPL', pandas.DataFrame()), ('IBM', pandas.DataFrame)] + + In command line: + $ python /home/jibi/Documents/repos/pyalgotrade/pyalgotrade/alpaca/restfeed.py + --symbols AAPL IBM + --start-date 2021-01-01 + --end-date 2021-01-10 + --storage /home/jibi/Documents/mochi/sample_data/test.csv """ @@ -158,7 +168,6 @@ async def get_historic_data(async_rest, symbols, start_date, end_date, help="The path were the files will be downloaded to") # parser.add_argument("--force-download", action='store_false', # help="Force downloading even if the files exist") - # Set up variables args = parser.parse_args() @@ -166,11 +175,11 @@ async def get_historic_data(async_rest, symbols, start_date, end_date, async_rest = common.make_async_rest_connection(args.api_key_id, args.api_secret_key) # storage - if not os.path.exists(args.storage): - common.logger.info("Creating %s directory" % (args.storage)) - os.mkdir(args.storage) + # if not os.path.exists(args.storage): + # common.logger.info("Creating %s directory" % (args.storage)) + # os.mkdir(args.storage) storage = args.storage - + # rest of data request symbols = args.symbols start_date = args.start_date @@ -183,7 +192,6 @@ async def get_historic_data(async_rest, symbols, start_date, end_date, results = loop.run_until_complete( get_historic_data(async_rest, symbols, start_date, end_date, datatype, timeframe) ) - # Stack the results into 1 dataframe # Current it is in [(symbol0, df0), (symbol1, df1)] format result = None @@ -193,7 +201,7 @@ async def get_historic_data(async_rest, symbols, start_date, end_date, if result is None: result = df_i else: - result = pd.concat([result, df_i], axis = 0, ignore_index = True) + result = pd.concat([result, df_i], axis = 0, ignore_index = False) # save to csv result.to_csv(storage) \ No newline at end of file From 1167b2c3036a6c3a4ae376e6fe0f7cf2cc7e37b1 Mon Sep 17 00:00:00 2001 From: jibi Date: Sat, 8 Jan 2022 17:40:12 -0800 Subject: [PATCH 6/7] work on alpaca integration --- pyalgotrade/__init__.py | 24 + pyalgotrade/alpaca/broker.py | 180 ++-- pyalgotrade/alpaca/common.py | 48 +- .../alpaca/{restfeed.py => historicaldata.py} | 34 +- pyalgotrade/alpaca/livebroker.py | 953 +++++++++++++----- pyalgotrade/alpaca/livefeed.py | 423 ++++---- pyalgotrade/alpaca/wsclient.py | 52 +- pyalgotrade/bar.py | 2 +- pyalgotrade/dataseries/quoteds.py | 95 ++ pyalgotrade/dataseries/tradeds.py | 71 ++ pyalgotrade/quote.py | 158 +++ pyalgotrade/trade.py | 107 ++ 12 files changed, 1536 insertions(+), 611 deletions(-) rename pyalgotrade/alpaca/{restfeed.py => historicaldata.py} (86%) create mode 100644 pyalgotrade/dataseries/quoteds.py create mode 100644 pyalgotrade/dataseries/tradeds.py create mode 100644 pyalgotrade/quote.py create mode 100644 pyalgotrade/trade.py diff --git a/pyalgotrade/__init__.py b/pyalgotrade/__init__.py index 7935e8ca4..7f7a16663 100644 --- a/pyalgotrade/__init__.py +++ b/pyalgotrade/__init__.py @@ -20,3 +20,27 @@ name = "PyAlgoTrade" __version__ = "0.20" + + +class Frequency(object): + + """Enum like class for bar frequencies. Valid values are: + + * **Frequency.TRADE**: The bar represents a single trade. + * **Frequency.SECOND**: The bar summarizes the trading activity during 1 second. + * **Frequency.MINUTE**: The bar summarizes the trading activity during 1 minute. + * **Frequency.HOUR**: The bar summarizes the trading activity during 1 hour. + * **Frequency.DAY**: The bar summarizes the trading activity during 1 day. + * **Frequency.WEEK**: The bar summarizes the trading activity during 1 week. + * **Frequency.MONTH**: The bar summarizes the trading activity during 1 month. + """ + + # It is important for frequency values to get bigger for bigger windows. + TRADE = -1 + QUOTE = -1 + SECOND = 1 + MINUTE = 60 + HOUR = 60*60 + DAY = 24*60*60 + WEEK = 24*60*60*7 + MONTH = 24*60*60*31 \ No newline at end of file diff --git a/pyalgotrade/alpaca/broker.py b/pyalgotrade/alpaca/broker.py index 1a1717aab..c4d93d809 100644 --- a/pyalgotrade/alpaca/broker.py +++ b/pyalgotrade/alpaca/broker.py @@ -15,7 +15,9 @@ # limitations under the License. """ -.. moduleauthor:: Gabriel Martin Becedillas Ruiz +.. moduleauthor:: Robert Lee + +Don't think we need a custom backtester for alpaca. """ @@ -33,91 +35,91 @@ # strategy. -class BacktestingBroker(backtesting.Broker): - MIN_TRADE_USD = 5 - - """A Bitstamp backtesting broker. - - :param cash: The initial amount of cash. - :type cash: int/float. - :param barFeed: The bar feed that will provide the bars. - :type barFeed: :class:`pyalgotrade.barfeed.BarFeed` - :param fee: The fee percentage for each order. Defaults to 0.25%. - :type fee: float. - - .. note:: - * Only limit orders are supported. - * Orders are automatically set as **goodTillCanceled=True** and **allOrNone=False**. - * BUY_TO_COVER orders are mapped to BUY orders. - * SELL_SHORT orders are mapped to SELL orders. - """ - - def __init__(self, cash, barFeed, fee=0.0025): - commission = backtesting.TradePercentage(fee) - super(BacktestingBroker, self).__init__(cash, barFeed, commission) - - def getInstrumentTraits(self, instrument): - return common.BTCTraits() - - def submitOrder(self, order): - if order.isInitial(): - # Override user settings based on Bitstamp limitations. - order.setAllOrNone(False) - order.setGoodTillCanceled(True) - return super(BacktestingBroker, self).submitOrder(order) - - def createMarketOrder(self, action, instrument, quantity, onClose=False): - raise Exception("Market orders are not supported") - - def createLimitOrder(self, action, instrument, limitPrice, quantity): - if instrument != common.btc_symbol: - raise Exception("Only BTC instrument is supported") - - if action == broker.Order.Action.BUY_TO_COVER: - action = broker.Order.Action.BUY - elif action == broker.Order.Action.SELL_SHORT: - action = broker.Order.Action.SELL - - if limitPrice * quantity < BacktestingBroker.MIN_TRADE_USD: - raise Exception("Trade must be >= %s" % (BacktestingBroker.MIN_TRADE_USD)) - - if action == broker.Order.Action.BUY: - # Check that there is enough cash. - fee = self.getCommission().calculate(None, limitPrice, quantity) - cashRequired = limitPrice * quantity + fee - if cashRequired > self.getCash(False): - raise Exception("Not enough cash") - elif action == broker.Order.Action.SELL: - # Check that there are enough coins. - if quantity > self.getShares(common.btc_symbol): - raise Exception("Not enough %s" % (common.btc_symbol)) - else: - raise Exception("Only BUY/SELL orders are supported") - - return super(BacktestingBroker, self).createLimitOrder(action, instrument, limitPrice, quantity) - - def createStopOrder(self, action, instrument, stopPrice, quantity): - raise Exception("Stop orders are not supported") - - def createStopLimitOrder(self, action, instrument, stopPrice, limitPrice, quantity): - raise Exception("Stop limit orders are not supported") - - -class PaperTradingBroker(BacktestingBroker): - """A Bitstamp paper trading broker. - - :param cash: The initial amount of cash. - :type cash: int/float. - :param barFeed: The bar feed that will provide the bars. - :type barFeed: :class:`pyalgotrade.barfeed.BarFeed` - :param fee: The fee percentage for each order. Defaults to 0.5%. - :type fee: float. - - .. note:: - * Only limit orders are supported. - * Orders are automatically set as **goodTillCanceled=True** and **allOrNone=False**. - * BUY_TO_COVER orders are mapped to BUY orders. - * SELL_SHORT orders are mapped to SELL orders. - """ - - pass +# class BacktestingBroker(backtesting.Broker): +# MIN_TRADE_USD = 5 + +# """An Alpaca backtesting broker. + +# :param cash: The initial amount of cash. +# :type cash: int/float. +# :param barFeed: The bar feed that will provide the bars. +# :type barFeed: :class:`pyalgotrade.barfeed.BarFeed` +# :param fee: The fee percentage for each order. Defaults to 0.25%. +# :type fee: float. + +# .. note:: +# * Only limit orders are supported. +# * Orders are automatically set as **goodTillCanceled=True** and **allOrNone=False**. +# * BUY_TO_COVER orders are mapped to BUY orders. +# * SELL_SHORT orders are mapped to SELL orders. +# """ + +# def __init__(self, cash, barFeed, fee=0.0025): +# commission = backtesting.TradePercentage(fee) +# super(BacktestingBroker, self).__init__(cash, barFeed, commission) + +# # def getInstrumentTraits(self, instrument): +# # return common.BTCTraits() + +# def submitOrder(self, order): +# # if order.isInitial(): +# # # Override user settings based on Bitstamp limitations. +# # order.setAllOrNone(False) +# # order.setGoodTillCanceled(True) +# return super(BacktestingBroker, self).submitOrder(order) + +# # def createMarketOrder(self, action, instrument, quantity, onClose=False): +# # raise Exception("Market orders are not supported") + +# def createLimitOrder(self, action, instrument, limitPrice, quantity): +# if instrument != common.btc_symbol: +# raise Exception("Only BTC instrument is supported") + +# if action == broker.Order.Action.BUY_TO_COVER: +# action = broker.Order.Action.BUY +# elif action == broker.Order.Action.SELL_SHORT: +# action = broker.Order.Action.SELL + +# if limitPrice * quantity < BacktestingBroker.MIN_TRADE_USD: +# raise Exception("Trade must be >= %s" % (BacktestingBroker.MIN_TRADE_USD)) + +# if action == broker.Order.Action.BUY: +# # Check that there is enough cash. +# fee = self.getCommission().calculate(None, limitPrice, quantity) +# cashRequired = limitPrice * quantity + fee +# if cashRequired > self.getCash(False): +# raise Exception("Not enough cash") +# elif action == broker.Order.Action.SELL: +# # Check that there are enough coins. +# if quantity > self.getShares(common.btc_symbol): +# raise Exception("Not enough %s" % (common.btc_symbol)) +# else: +# raise Exception("Only BUY/SELL orders are supported") + +# return super(BacktestingBroker, self).createLimitOrder(action, instrument, limitPrice, quantity) + +# def createStopOrder(self, action, instrument, stopPrice, quantity): +# raise Exception("Stop orders are not supported") + +# def createStopLimitOrder(self, action, instrument, stopPrice, limitPrice, quantity): +# raise Exception("Stop limit orders are not supported") + + +# class PaperTradingBroker(BacktestingBroker): +# """A Bitstamp paper trading broker. + +# :param cash: The initial amount of cash. +# :type cash: int/float. +# :param barFeed: The bar feed that will provide the bars. +# :type barFeed: :class:`pyalgotrade.barfeed.BarFeed` +# :param fee: The fee percentage for each order. Defaults to 0.5%. +# :type fee: float. + +# .. note:: +# * Only limit orders are supported. +# * Orders are automatically set as **goodTillCanceled=True** and **allOrNone=False**. +# * BUY_TO_COVER orders are mapped to BUY orders. +# * SELL_SHORT orders are mapped to SELL orders. +# """ + +# pass diff --git a/pyalgotrade/alpaca/common.py b/pyalgotrade/alpaca/common.py index 665b1428b..2aba4b16a 100644 --- a/pyalgotrade/alpaca/common.py +++ b/pyalgotrade/alpaca/common.py @@ -15,12 +15,13 @@ # limitations under the License. """ -.. moduleauthor:: Gabriel Martin Becedillas Ruiz +.. moduleauthor:: Robert Lee """ -from enum import Enum import os +import alpaca_trade_api as tradeapi from alpaca_trade_api.rest_async import AsyncRest +from alpaca_trade_api.stream import Stream import pyalgotrade.logger from pyalgotrade import broker @@ -28,9 +29,19 @@ logger = pyalgotrade.logger.getLogger("alpaca") +def make_connection(connection_type, api_key_id = None, api_secret_key = None): + """Makes a connection to Alpaca. + + https://alpaca.markets/docs/api-documentation/api-v2/ + + Args: + connection_type: The connection to make to Alpaca. One of [rest, async_rest, stream]. + api_key_id (str, optional): If none, looks at the environment variable ALPACA_API_KEY_ID. + Defaults to None. + api_secret_key (str, optional): If none, looks at the environment variable ALPACA_API_SECRET_KEY. + Defaults to None. + """ -def make_async_rest_connection(api_key_id = None, api_secret_key = None): - # credentials api_key_id = api_key_id or os.environ.get('ALPACA_API_KEY_ID') api_secret_key = api_secret_key or os.environ.get('ALPACA_API_SECRET_KEY') @@ -39,28 +50,15 @@ def make_async_rest_connection(api_key_id = None, api_secret_key = None): logger.error('Unable to retrieve API Key ID.') if api_key_id is None: logger.error('Unable to retrieve API Secret Key.') - - rest = AsyncRest(key_id=api_key_id, - secret_key=api_secret_key) - - return rest - - - - - - - - - - - - - - - - + if connection_type == 'async_rest': + connection = AsyncRest(key_id=api_key_id, secret_key=api_secret_key) + elif connection_type == 'rest': + connection = tradeapi.REST(key_id = api_key_id, secret_key = api_secret_key) + elif connection_type == 'stream': + connection = Stream(data_feed = 'IEX', key_id=api_key_id, secret_key=api_secret_key, raw_data = True) + + return connection # btc_symbol = "BTC" diff --git a/pyalgotrade/alpaca/restfeed.py b/pyalgotrade/alpaca/historicaldata.py similarity index 86% rename from pyalgotrade/alpaca/restfeed.py rename to pyalgotrade/alpaca/historicaldata.py index 2812288d8..21bf96f5a 100644 --- a/pyalgotrade/alpaca/restfeed.py +++ b/pyalgotrade/alpaca/historicaldata.py @@ -21,15 +21,15 @@ Example usage: In a script: - from pyalgotrade.alpaca.common import make_async_rest_connection - from pyalgotrade.alpaca.restfeed import get_historic_data + from pyalgotrade.alpaca.common import make_connection + from pyalgotrade.alpaca.historicaldata import get_historical_data - async_rest = make_async_rest_connection() - results = get_historic_data(async_rest, ['AAPL', 'IBM'], '2021-01-01', '2021-01-10, 'QUOTES') + async_rest = make_connection(connection_type = 'async_rest') + results = await get_historical_data(async_rest, ['AAPL', 'IBM'], '2021-01-01', '2021-01-10, 'QUOTES') # results = [('AAPL', pandas.DataFrame()), ('IBM', pandas.DataFrame)] In command line: - $ python /home/jibi/Documents/repos/pyalgotrade/pyalgotrade/alpaca/restfeed.py + $ python /home/jibi/Documents/repos/pyalgotrade/pyalgotrade/alpaca/historicaldata.py --symbols AAPL IBM --start-date 2021-01-01 --end-date 2021-01-10 @@ -38,32 +38,23 @@ """ - -from enum import Enum -import datetime -import os import sys import asyncio import argparse import pandas as pd -import alpaca_trade_api as tradeapi -from alpaca_trade_api.rest import TimeFrame, URL -from alpaca_trade_api.rest_async import gather_with_concurrency, AsyncRest +from alpaca_trade_api.rest_async import gather_with_concurrency from pyalgotrade.alpaca import common -# from pyalgotrade.alpaca import livefeed - -# LiveTradeFeed = livefeed.LiveTradeFeed NY = 'America/New_York' -async def get_historic_data(async_rest, symbols, start_date, end_date, +async def get_historical_data(async_rest, symbols, start_date, end_date, data_type = 'BARS', timeframe = '1Day'): """ - Retrieve historic data for multiple symbols using Alpaca's get_[datatype]_async + Retrieve historical data for multiple symbols using Alpaca's get_[datatype]_async from the AsyncRest object. Args: @@ -78,11 +69,6 @@ async def get_historic_data(async_rest, symbols, start_date, end_date, Returns: [(symbol, df),]: List of tuples of (symbol, pandas DataFrame) - - Usage: - async_rest = make_async_rest_connetion() - symbols, dfs = - """ # Check Python version major = sys.version_info.major @@ -172,7 +158,7 @@ async def get_historic_data(async_rest, symbols, start_date, end_date, args = parser.parse_args() # make rest connection to API - async_rest = common.make_async_rest_connection(args.api_key_id, args.api_secret_key) + async_rest = common.make_connection('async_rest', args.api_key_id, args.api_secret_key) # storage # if not os.path.exists(args.storage): @@ -190,7 +176,7 @@ async def get_historic_data(async_rest, symbols, start_date, end_date, # Request the data loop = asyncio.get_event_loop() results = loop.run_until_complete( - get_historic_data(async_rest, symbols, start_date, end_date, datatype, timeframe) + get_historical_data(async_rest, symbols, start_date, end_date, datatype, timeframe) ) # Stack the results into 1 dataframe # Current it is in [(symbol0, df0), (symbol1, df1)] format diff --git a/pyalgotrade/alpaca/livebroker.py b/pyalgotrade/alpaca/livebroker.py index 6bb8edf7e..5b4f8e169 100644 --- a/pyalgotrade/alpaca/livebroker.py +++ b/pyalgotrade/alpaca/livebroker.py @@ -15,164 +15,73 @@ # limitations under the License. """ -.. moduleauthor:: Gabriel Martin Becedillas Ruiz +.. moduleauthor:: Robert Lee """ +from os import kill import threading import time +import alpaca from six.moves import queue +from ws4py.websocket import EchoWebSocket -from pyalgotrade import broker -from pyalgotrade.bitstamp import httpclient -from pyalgotrade.bitstamp import common - - -def build_order_from_open_order(openOrder, instrumentTraits): - if openOrder.isBuy(): - action = broker.Order.Action.BUY - elif openOrder.isSell(): - action = broker.Order.Action.SELL - else: - raise Exception("Invalid order type") - - ret = broker.LimitOrder(action, common.btc_symbol, openOrder.getPrice(), openOrder.getAmount(), instrumentTraits) - ret.setSubmitted(openOrder.getId(), openOrder.getDateTime()) - ret.setState(broker.Order.State.ACCEPTED) - return ret - - -class TradeMonitor(threading.Thread): - POLL_FREQUENCY = 2 - - # Events - ON_USER_TRADE = 1 - - def __init__(self, httpClient): - super(TradeMonitor, self).__init__() - self.__lastTradeId = -1 - self.__httpClient = httpClient - self.__queue = queue.Queue() - self.__stop = False +import zmq - def _getNewTrades(self): - userTrades = self.__httpClient.getUserTransactions(httpclient.HTTPClient.UserTransactionType.MARKET_TRADE) - - # Get the new trades only. - ret = [t for t in userTrades if t.getId() > self.__lastTradeId] - - # Sort by id, so older trades first. - return sorted(ret, key=lambda t: t.getId()) - - def getQueue(self): - return self.__queue +from pyalgotrade import broker +from pyalgotrade.alpaca import httpclient +from pyalgotrade.alpaca import common +from alpaca.livefeed import EventQueuer - def start(self): - trades = self._getNewTrades() - # Store the last trade id since we'll start processing new ones only. - if len(trades): - self.__lastTradeId = trades[-1].getId() - common.logger.info("Last trade found: %d" % (self.__lastTradeId)) - - super(TradeMonitor, self).start() - - def run(self): - while not self.__stop: - try: - trades = self._getNewTrades() - if len(trades): - self.__lastTradeId = trades[-1].getId() - common.logger.info("%d new trade/s found" % (len(trades))) - self.__queue.put((TradeMonitor.ON_USER_TRADE, trades)) - except Exception as e: - common.logger.critical("Error retrieving user transactions", exc_info=e) - - time.sleep(TradeMonitor.POLL_FREQUENCY) +from observer import Event - def stop(self): - self.__stop = True +class LiveBroker(broker.Broker): + """An Alpaca live broker. + The live broker listens to a ZMQ SUB socket for trade updates, + and uses a rest connection to get account info and place trades. -class LiveBroker(broker.Broker): - """A Bitstamp live broker. - - :param clientId: Client id. - :type clientId: string. - :param key: API key. - :type key: string. - :param secret: API secret. - :type secret: string. - - - .. note:: - * Only limit orders are supported. - * Orders are automatically set as **goodTillCanceled=True** and **allOrNone=False**. - * BUY_TO_COVER orders are mapped to BUY orders. - * SELL_SHORT orders are mapped to SELL orders. - * API access permissions should include: - - * Account balance - * Open orders - * Buy limit order - * User transactions - * Cancel order - * Sell limit order + :param liveFeedAddress: Address to which the ZMQ SUB socket should be connected. + :type liveFeedAddress: string. + :param restConnection: An Alpaca rest connection from alpaca_trade_api. + :type restConnection: string. """ QUEUE_TIMEOUT = 0.01 - def __init__(self, clientId, key, secret): + def __init__(self, liveFeedAddress, restConnection): super(LiveBroker, self).__init__() + + self._restConnection = restConnection + + self.__tradeMonitor = EventQueuer(liveFeedAddress) self.__stop = False - self.__httpClient = self.buildHTTPClient(clientId, key, secret) - self.__tradeMonitor = TradeMonitor(self.__httpClient) - self.__cash = 0 - self.__shares = {} - self.__activeOrders = {} - - def _registerOrder(self, order): - assert(order.getId() not in self.__activeOrders) - assert(order.getId() is not None) - self.__activeOrders[order.getId()] = order - - def _unregisterOrder(self, order): - assert(order.getId() in self.__activeOrders) - assert(order.getId() is not None) - del self.__activeOrders[order.getId()] - - # Factory method for testing purposes. - def buildHTTPClient(self, clientId, key, secret): - return httpclient.HTTPClient(clientId, key, secret) - - def refreshAccountBalance(self): - """Refreshes cash and BTC balance.""" - - self.__stop = True # Stop running in case of errors. - common.logger.info("Retrieving account balance.") - balance = self.__httpClient.getAccountBalance() - - # Cash - self.__cash = round(balance.getUSDAvailable(), 2) - common.logger.info("%s USD" % (self.__cash)) - # BTC - btc = balance.getBTCAvailable() - if btc: - self.__shares = {common.btc_symbol: btc} + + def __getattr__(self, name): + """Transfer methods of the underlying api rest connection to the live broker. + """ + if hasattr(self._restConnection, name): + return getattr(self._restConnection, name) else: - self.__shares = {} - common.logger.info("%s BTC" % (btc)) - - self.__stop = False # No errors. Keep running. - - def refreshOpenOrders(self): - self.__stop = True # Stop running in case of errors. - common.logger.info("Retrieving open orders.") - openOrders = self.__httpClient.getOpenOrders() - for openOrder in openOrders: - self._registerOrder(build_order_from_open_order(openOrder, self.getInstrumentTraits(common.btc_symbol))) - - common.logger.info("%d open order/s found" % (len(openOrders))) - self.__stop = False # No errors. Keep running. + raise AttributeError + + @property + def account(self): + return self._restConnection.getAccount() + + @property + def cash(self): + return self.account['cash'] + + @property + def openPositions(self): + return self._restConnection.list_positions() + + @property + def openOrders(self): + orders = self._restConnection.list_orders(status = 'open') + orders = map(fromAlpacaOrder, orders) + return {order['client_order_id']: order for order in orders} def _startTradeMonitor(self): self.__stop = True # Stop running in case of errors. @@ -180,37 +89,10 @@ def _startTradeMonitor(self): self.__tradeMonitor.start() self.__stop = False # No errors. Keep running. - def _onUserTrades(self, trades): - for trade in trades: - order = self.__activeOrders.get(trade.getOrderId()) - if order is not None: - fee = trade.getFee() - fillPrice = trade.getBTCUSD() - btcAmount = trade.getBTC() - dateTime = trade.getDateTime() - - # Update cash and shares. - self.refreshAccountBalance() - # Update the order. - orderExecutionInfo = broker.OrderExecutionInfo(fillPrice, abs(btcAmount), fee, dateTime) - order.addExecutionInfo(orderExecutionInfo) - if not order.isActive(): - self._unregisterOrder(order) - # Notify that the order was updated. - if order.isFilled(): - eventType = broker.OrderEvent.Type.FILLED - else: - eventType = broker.OrderEvent.Type.PARTIALLY_FILLED - self.notifyOrderEvent(broker.OrderEvent(order, eventType, orderExecutionInfo)) - else: - common.logger.info("Trade %d refered to order %d that is not active" % (trade.getId(), trade.getOrderId())) # BEGIN observer.Subject interface def start(self): super(LiveBroker, self).start() - self.refreshAccountBalance() - self.refreshOpenOrders() - self._startTradeMonitor() def stop(self): self.__stop = True @@ -218,28 +100,18 @@ def stop(self): self.__tradeMonitor.stop() def join(self): - if self.__tradeMonitor.isAlive(): - self.__tradeMonitor.join() + pass def eof(self): return self.__stop def dispatch(self): - # Switch orders from SUBMITTED to ACCEPTED. - ordersToProcess = list(self.__activeOrders.values()) - for order in ordersToProcess: - if order.isSubmitted(): - order.switchState(broker.Order.State.ACCEPTED) - self.notifyOrderEvent(broker.OrderEvent(order, broker.OrderEvent.Type.ACCEPTED, None)) - - # Dispatch events from the trade monitor. try: - eventType, eventData = self.__tradeMonitor.getQueue().get(True, LiveBroker.QUEUE_TIMEOUT) + update = self.__tradeMonitor.getQueue().get( + block = True, timeout = LiveBroker.QUEUE_TIMEOUT) + order = update['order'] + self.notifyOrderEvent(fromAlpacaOrder(order)) - if eventType == TradeMonitor.ON_USER_TRADE: - self._onUserTrades(eventData) - else: - common.logger.error("Invalid event received to dispatch: %s - %s" % (eventType, eventData)) except queue.Empty: pass @@ -252,81 +124,680 @@ def peekDateTime(self): # BEGIN broker.Broker interface def getCash(self, includeShort=True): - return self.__cash + return self.cash def getInstrumentTraits(self, instrument): - return common.BTCTraits() + return broker.IntegerTraits() def getShares(self, instrument): - return self.__shares.get(instrument, 0) + return [pos for pos in self.openPositions if pos['symbol'] == instrument] def getPositions(self): - return self.__shares + return self.openPositions def getActiveOrders(self, instrument=None): - return list(self.__activeOrders.values()) - - def submitOrder(self, order): - if order.isInitial(): - # Override user settings based on Bitstamp limitations. - order.setAllOrNone(False) - order.setGoodTillCanceled(True) - - if order.isBuy(): - bitstampOrder = self.__httpClient.buyLimit(order.getLimitPrice(), order.getQuantity()) - else: - bitstampOrder = self.__httpClient.sellLimit(order.getLimitPrice(), order.getQuantity()) - - order.setSubmitted(bitstampOrder.getId(), bitstampOrder.getDateTime()) - self._registerOrder(order) - # Switch from INITIAL -> SUBMITTED - # IMPORTANT: Do not emit an event for this switch because when using the position interface - # the order is not yet mapped to the position and Position.onOrderUpdated will get called. - order.switchState(broker.Order.State.SUBMITTED) + if instrument is not None: + return [openOrder for openOrder in self.openOrders if openOrder.instrument == instrument] else: - raise Exception("The order was already processed") - - def createMarketOrder(self, action, instrument, quantity, onClose=False): - raise Exception("Market orders are not supported") - - def createLimitOrder(self, action, instrument, limitPrice, quantity): - if instrument != common.btc_symbol: - raise Exception("Only BTC instrument is supported") + return self.openOrders - if action == broker.Order.Action.BUY_TO_COVER: - action = broker.Order.Action.BUY - elif action == broker.Order.Action.SELL_SHORT: - action = broker.Order.Action.SELL - - if action not in [broker.Order.Action.BUY, broker.Order.Action.SELL]: - raise Exception("Only BUY/SELL orders are supported") + def submitOrder(self, order): + self._restConnection.submit_order(**toAlpacaOrder(order)) - instrumentTraits = self.getInstrumentTraits(instrument) - limitPrice = round(limitPrice, 2) - quantity = instrumentTraits.roundQuantity(quantity) - return broker.LimitOrder(action, instrument, limitPrice, quantity, instrumentTraits) + def cancelOrder(self, order): + self._restConnection.cancel_order(order.orderId) - def createStopOrder(self, action, instrument, stopPrice, quantity): - raise Exception("Stop orders are not supported") + # Notify that the order was canceled. + self.notifyOrderEvent(AlpacaOrder.OrderEvent(order, AlpacaOrder.OrderEvent.Type.CANCELED, "User requested cancellation")) - def createStopLimitOrder(self, action, instrument, stopPrice, limitPrice, quantity): - raise Exception("Stop limit orders are not supported") + # END broker.Broker interface - def cancelOrder(self, order): - activeOrder = self.__activeOrders.get(order.getId()) - if activeOrder is None: - raise Exception("The order is not active anymore") - if activeOrder.isFilled(): - raise Exception("Can't cancel order that has already been filled") + def getClock(self): + return self._restConnection.get_clock() + + def getCalendar(self, start = None, end = None): + return self._restConnection.get_calendar(start = start, end = end) + + def getPortfolioHistory(self, + dateStart = None, dateEnd = None, period = None, + timeframe = None, extendedHours = None): + return self._restConnection.get_portfolio_history( + date_start = dateStart, + date_end = dateEnd, + period = period, + timeframe = timeframe, + extended_hours = extendedHours + ) + +# Types of orders +class AlpacaOrder(broker.Order): + """Base class for Alpaca orders. + Contains a few more fields than the broker.Order class. + """ - self.__httpClient.cancelOrder(order.getId()) - self._unregisterOrder(order) - order.switchState(broker.Order.State.CANCELED) + class State: + # https://alpaca.markets/docs/trading-on-alpaca/orders/#order-lifecycle + + # Typical states + NEW = 1 + PARTIALLY_FILLED = 2 + FILLED = 3 + DONE_FOR_DAY = 4 + CANCELED = 5 + EXPIRED = 6 + REPLACED = 7 + PENDING_CANCEL = 8 + PENDING_REPLACE = 9 + + # Less common states + ACCEPTED = 101 + PENDING_NEW = 102 + ACCEPTED_FOR_BIDDING = 103 + STOPPED = 104 + REJECTED = 105 + SUSPENDED = 106 + CALCULATED = 107 + + @classmethod + def toString(cls, state): + if state == cls.NEW: + return 'new' + elif state == cls.PARTIALLY_FILLED: + return 'partially_filled' + elif state == cls.FILLED: + return 'filled' + elif state == cls.DONE_FOR_DAY: + return 'done_for_day' + elif state == cls.CANCELED: + return 'canceled' + elif state == cls.EXPIRED: + return 'expired' + elif state == cls.REPLACED: + return 'replaced' + elif state == cls.PENDING_CANCEL: + return 'pending_cancel' + elif state == cls.PENDING_REPLACE: + return 'pending_replace' + elif state == cls.ACCEPTED: + return 'accepted' + elif state == cls.PENDING_NEW: + return 'pending_new' + elif state == cls.ACCEPTED_FOR_BIDDING: + return 'accepted_for_bidding' + elif state == cls.STOPPED: + return 'stopped' + elif state == cls.REJECTED: + return 'rejected' + elif state == cls.SUSPENDED: + return 'suspended' + elif state == cls.CALCULATED: + return 'calculated' + else: + raise Exception("Invalid state") + + @classmethod + def fromString(cls, strState): + if strState == 'new': + return cls.NEW + elif strState == 'partially_filled': + return cls.PARTIALLY_FILLED + elif strState == 'filled': + return cls.FILLED + elif strState == 'done_for_day': + return cls.DONE_FOR_DAY + elif strState == 'canceled': + return cls.CANCELED + elif strState == 'expired': + return cls.EXPIRED + elif strState == 'replaced': + return cls.REPLACED + elif strState == 'pending_cancel': + return cls.PENDING_CANCEL + elif strState == 'pending_replace': + return cls.PENDING_REPLACE + elif strState == 'accepted': + return cls.ACCEPTED + elif strState == 'pending_new': + return cls.PENDING_NEW + elif strState == 'accepted_for_bidding': + return cls.ACCEPTED_FOR_BIDDING + elif strState == 'stopped': + return cls.STOPPED + elif strState == 'rejected': + return cls.REJECTED + elif strState == 'suspended': + return cls.SUSPENDED + elif strState == 'calculated': + return cls.CALCULATED + else: + raise Exception('Invalid order state') + + class Type(broker.Order.Type): + # MARKET = 1 + # LIMIT = 2 + # STOP = 3 + # STOP_LIMIT = 4 + TRAILING_STOP = 5 + + @classmethod + def toString(cls, type_): + if type_ == 'market': + return cls.MARKET + elif type_ == 'limit': + return cls.LIMIT + elif type_ == 'stop': + return cls.STOP + elif type_ == 'stop_limit': + return cls.STOP_LIMIT + elif type == 'trailing_stop': + return cls.TRAILING_STOP + else: + raise Exception('Inavlid order type') + + @classmethod + def fromString(cls, strType): + if strType == cls.MARKET: + return 'market' + elif strType == cls.LIMIT: + return 'limit' + elif strType == cls.STOP: + return 'stop' + elif strType == cls.STOP_LIMIT: + return 'stop_limit' + elif strType == cls.TRAILING_STOP: + return 'trailing_stop' + else: + raise Exception('Invalid order type') + + class OrderClass(object): + SIMPLE = 1 + BRACKET = 2 + OCO = 3 + OTO = 4 + + @classmethod + def toString(cls, orderclass): + if orderclass == cls.SIMPLE: + return 'simple' + elif orderclass == cls.BRACKET: + return 'bracket' + elif orderclass == cls.OCO: + return 'oco' + elif orderclass == cls.OTO: + return 'oto' + else: + raise Exception('Inavlid order class') + + @classmethod + def fromString(cls, strOrderClass): + if strOrderClass == 'simple': + return cls.SIMPLE + elif strOrderClass == 'bracket': + return cls.BRACKET + elif strOrderClass == 'oco': + return cls.OCO + elif strOrderClass == 'oto': + return cls.OTO + else: + raise Exception('Inavlid order class') - # Update cash and shares. - self.refreshAccountBalance() + class Action(broker.Order.Action): - # Notify that the order was canceled. - self.notifyOrderEvent(broker.OrderEvent(order, broker.OrderEvent.Type.CANCELED, "User requested cancellation")) + @classmethod + def toString(cls, action): + if action == 'buy': + return cls.BUY + elif action == 'sell': + return cls.SELL + else: + raise Exception('Inavlid order action') + + @classmethod + def fromString(cls, strAction): + if strAction == cls.BUY: + return 'buy' + elif strAction == cls.BUY_TO_COVER: + return 'buy' + elif strAction == cls.SELL: + return 'sell' + elif strAction == cls.SELL_SHORT: + return 'sell' + else: + raise Exception('Inavlid order action') + + class TimeInForce(object): + # https://alpaca.markets/docs/trading-on-alpaca/orders/#time-in-force + + DAY = 1 # good for day + GTC = 2 # good till canceled + OPG = 3 # market on open / limit on open + CLS = 4 # market on close / limt on close + IOC = 5 # immediate or cancel + FOK = 6 # fill or kill + + @classmethod + def toString(cls, timeInForce): + if timeInForce == cls.DAY: + return 'day' + elif timeInForce == cls.GTC: + return 'gtc' + elif timeInForce == cls.OPG: + return 'opg' + elif timeInForce == cls.CLS: + return 'cls' + elif timeInForce == cls.IOC: + return 'ioc' + elif timeInForce == cls.FOK: + return 'fok' + else: + raise Exception('Inavlid order time in force') + + @classmethod + def fromString(cls, strTimeInForce): + if strTimeInForce == 'day': + return cls.DAY + elif strTimeInForce == 'gtc': + return cls.GTC + elif strTimeInForce == 'opg': + return cls.OPG + elif strTimeInForce == 'cls': + return cls.CLS + elif strTimeInForce == 'ioc': + return cls.IOC + elif strTimeInForce == 'fok': + return cls.fok + else: + raise Exception('Inavlid order time in force') + + + def __init__( + self, + # broker.Order attributes + type_, + action, + instrument, + quantity, + instrumentTraits = broker.InstrumentTraits(), + # Alpaca-specific attributes + orderId = None, + clientOrderId = None, + createdAt = None, + updatedAt = None, + submittedAt = None, + filledAt = None, + expiredAt = None, + canceledAt = None, + failedAt = None, + replacedAt = None, + replacedBy = None, + replaces = None, + assetId = None, + assetClass = None, + notional = None, + filledQuantity = None, + filledAveragePrice = None, + orderClass = None, + timeInForce = None, + limiPrice = None, + stopPrice = None, + extendedHours = False, + legs = None, + trailPercent = None, + trailPrice = None, + hwm = None, + # for bracket orders + takeProfit = None, + stopLossStop = None, + stopLossLimit = None + ): + super(AlpacaOrder, self).__init__(type_, action, instrument, quantity, instrumentTraits) + self.orderId = orderId + self.clientOrderId = clientOrderId + self.createdAt = createdAt + self.updatedAt = updatedAt + self.submittedAt = submittedAt + self.filledAt = filledAt + self.expiredAt = expiredAt + self.canceledAt = canceledAt + self.failedAt = failedAt + self.replacedAt = replacedAt + self.replacedBy = replacedBy + self.replaces = replaces + self.assetId = assetId + self.assetClass = assetClass + self.notional = notional + self.filledQuantity = filledQuantity + self.filledAveragePrice = filledAveragePrice + self.orderClass = orderClass + self.timeInForce = timeInForce + self.limiPrice = limiPrice + self.stopPrice = stopPrice + self.extendedHours = extendedHours + self.legs = legs + self.trailPercent = trailPercent + self.trailPrice = trailPrice + self.hwm = hwm + # for bracket orders + self.takeProfit = takeProfit + self.stopLossStop = stopLossStop + self.stopLossLimit = stopLossLimit + +class MarketOrder(AlpacaOrder): + """A market order is a request to buy or sell a security at the + currently available market price. + """ + def __init__(self, action, instrument, quantity = None, notional = None, **kwargs): + super(MarketOrder, self).__init__( + type_ = AlpacaOrder.Type.MARKET, + action = action, + instrument = instrument, + quantity = quantity, + notional = notional, + **kwargs + ) + +class LimitOrder(AlpacaOrder): + """A limit order is an order to buy or sell at a specified price or better. + """ + def __init__(self, action, instrument, limitPrice, quantity = None, notional = None, **kwargs): + super(LimitOrder, self).__init__( + type_ = AlpacaOrder.Type.LIMIT, + action = action, + instrument = instrument, + limitPrice = limitPrice, + quantity = quantity, + notional = notional, + **kwargs + ) + +class StopOrder(AlpacaOrder): + """A stop (market) order is an order to buy or sell a security + when its price moves past a particular point, + ensuring a higher probability of achieving a predetermined + entry or exit price. + + NOTE: Alpaca converts buy stop orders into stop limit orders + with a limit price that is 4% higher than a stop price < $50 + (or 2.5% higher than a stop price >= $50). + """ + def __init__(self, action, instrument, stopPrice, quantity = None, notional = None, **kwargs): + super(LimitOrder, self).__init__( + type_ = AlpacaOrder.Type.LIMIT, + action = action, + instrument = instrument, + stopPrice = stopPrice, + quantity = quantity, + notional = notional, + **kwargs + ) + +class StopLimitOrder(AlpacaOrder): + """A stop-limit order is a conditional trade over a set time frame + that combines the features of a stop order with those of a limit order + and is used to mitigate risk. + """ + def __init__(self, action, instrument, stopPrice, limitPrice, quantity = None, notional = None, **kwargs): + super(LimitOrder, self).__init__( + type_ = AlpacaOrder.Type.LIMIT, + action = action, + instrument = instrument, + stopPrice = stopPrice, + limitPrice = limitPrice, + quantity = quantity, + notional = notional, + **kwargs + ) + +class MarketOnOpenOrder(MarketOrder): + def __init__(self, action, instrument, quantity = None, notional = None, **kwargs): + super(MarketOrder, self).__init__( + type_ = AlpacaOrder.Type.MARKET, + action = action, + instrument = instrument, + quantity = quantity, + notional = notional, + timeInForce = AlpacaOrder.TimeInForce.OPG + **kwargs + ) + +class MarketOnCloseOrder(MarketOrder): + def __init__(self, action, instrument, quantity = None, notional = None, **kwargs): + super(MarketOrder, self).__init__( + type_ = AlpacaOrder.Type.MARKET, + action = action, + instrument = instrument, + quantity = quantity, + notional = notional, + timeInForce = AlpacaOrder.TimeInForce.CLS + **kwargs + ) + +class LimitOnOpenOrder(LimitOrder): + def __init__(self, action, instrument, limitPrice, quantity = None, notional = None, **kwargs): + super(LimitOrder, self).__init__( + type_ = AlpacaOrder.Type.LIMIT, + action = action, + instrument = instrument, + limitPrice = limitPrice, + quantity = quantity, + notional = notional, + timeInForce = AlpacaOrder.TimeInForce.OPG + **kwargs + ) + +class LimitOnCloseOrder(LimitOrder): + def __init__(self, action, instrument, limitPrice, quantity = None, notional = None, **kwargs): + super(LimitOrder, self).__init__( + type_ = AlpacaOrder.Type.LIMIT, + action = action, + instrument = instrument, + limitPrice = limitPrice, + quantity = quantity, + notional = notional, + timeInForce = AlpacaOrder.TimeInForce.CLS + **kwargs + ) + +class BracketOrder(AlpacaOrder): + """A bracket order is a chain of three orders that can be used to + manage your position entry and exit. + + Args: + openingOrder (AlpacaOrder): The opening order for the bracket order. + takeProfitLimit (numeric): The limit price for the exiting take profit limit order. + stopLossStop (numeric): The price to trigger the exiting stop loss order. + stopLossLimit (numeric, Optional): The limit price for the exiting stop loss order + if the stop loss order is a limit order. + + Returns: + AlpacaOrder: A bracket order. + """ + + def __new__(cls, openingOrder, takeProfitLimit, stopLossStop, stopLossLimit = None): + + # check exiting order conditions + if openingOrder.action == AlpacaOrder.Action.BUY: + assert takeProfitLimit > stopLossStop, \ + 'Take profit price must be greater than stop price for buy orders.' + elif openingOrder.action == AlpacaOrder.Action.SELL: + assert takeProfitLimit < stopLossStop, \ + 'take profit price must be less than stop price for sell orders.' + else: + raise Exception('Invalid order action: {openingOrder.action}') + + if openingOrder.extendedHours: + raise Exception( + 'Extended hours are not supported for bracket orders, ' + \ + 'converting to regular hours order.' + ) + + if openingOrder.TimeInForce not in [AlpacaOrder.TimeInForce.DAY, AlpacaOrder.TimeInForce.GTC]: + raise Exception( + 'Time in force must be "day" or "gtc".' + ) + + order = openingOrder + order.takeProfitLimit = takeProfitLimit + order.stopLossStop = stopLossStop + order.stopLossLimit = stopLossLimit + order.orderClass = AlpacaOrder.OrderClass.BRACKET + + return order + +class OneCancelsOtherOrder(AlpacaOrder): + """This is a set of two orders with the same side (buy/buy or sell/sell) and + currently only exit order is supported. + """ + def __init__(self, action, instrument, takeProfitLimit, stopLossStop, stopLossLimit = None): + super(OneCancelsOtherOrder, self).__init__( + type_ = AlpacaOrder.Type.LIMIT, # OCO orders must be placed as limit orders + action = action, + instrument = instrument, + takeProfitLimit = takeProfitLimit, + stopLossStop = stopLossStop, + stopLossLimit = stopLossLimit + ) + +class OneTriggersOther(AlpacaOrder): + """OTO (One-Triggers-Other) is a variant of bracket order. + It takes one of the take-profit or stop-loss order in addition to the entry order. + """ + def __init__(self, action, instrument, takeProfitLimit, stopLossStop, stopLossLimit = None): + super(OneCancelsOtherOrder, self).__init__( + type_ = AlpacaOrder.Type.LIMIT, # OCO orders must be placed as limit orders + action = action, + instrument = instrument, + takeProfitLimit = takeProfitLimit, + stopLossStop = stopLossStop, + stopLossLimit = stopLossLimit + ) + + +class OneTriggersOther(AlpacaOrder): + """OTO (One-Triggers-Other) is a variant of bracket order. + It takes one of the take-profit or stop-loss order in addition to the entry order. + + Args: + openingOrder (AlpacaOrder): The opening order for the bracket order. + takeProfitLimit (numeric, Optional): The limit price for the exiting take profit limit order. + stopLossStop (numeric, Optional): The price to trigger the exiting stop loss order. + stopLossLimit (numeric, Optional): The limit price for the exiting stop loss order + if the stop loss order is a limit order. + + Returns: + AlpacaOrder: An OTO order. + """ + + def __new__(cls, openingOrder, takeProfitLimit = None, stopLossStop = None, stopLossLimit = None): + + # check exiting order conditions + if openingOrder.extendedHours: + raise Exception( + 'Extended hours are not supported for bracket orders, ' + \ + 'converting to regular hours order.' + ) + + if openingOrder.TimeInForce not in [AlpacaOrder.TimeInForce.DAY, AlpacaOrder.TimeInForce.GTC]: + raise Exception( + 'Time in force must be "day" or "gtc".' + ) + + assert (takeProfitLimit or stopLossStop) is not None, \ + 'One of takeProfitLimit or stopLossStop must be present' + + order = openingOrder + order.takeProfitLimit = takeProfitLimit + order.stopLossStop = stopLossStop + order.stopLossLimit = stopLossLimit + order.orderClass = AlpacaOrder.OrderClass.OTO + + return order + + + +# Order constructers +def fromAlpacaOrder(alpacaOrderEntity): + # https://alpaca.markets/docs/api-documentation/api-v2/orders/#order-entity + order = AlpacaOrder( + type_ = AlpacaOrder.Type.fromString(alpacaOrderEntity['type']), + action = AlpacaOrder.Action.fromString(alpacaOrderEntity['side']), + instrument = alpacaOrderEntity['symbol'], + quantity = alpacaOrderEntity['qty'], + instrumentTraits = broker.IntegerTraits(), + orderId = alpacaOrderEntity['order_id'], + clientOrderId = alpacaOrderEntity['client_order_id'], + createdAt = alpacaOrderEntity['created_at'], + updatedAt = alpacaOrderEntity['updated_at'], + submittedAt = alpacaOrderEntity['submitted_at'], + filledAt = alpacaOrderEntity['filled_at'], + expiredAt = alpacaOrderEntity['expired_at'], + canceledAt = alpacaOrderEntity['canceled_at'], + failedAt = alpacaOrderEntity['failed_at'], + replacedAt = alpacaOrderEntity['replaced_at'], + replacedBy = alpacaOrderEntity['replaced_by'], + replaces = alpacaOrderEntity['replaces'], + assetId = alpacaOrderEntity['asset_id'], + assetClass = alpacaOrderEntity['asset_class'], + filledQuantity = alpacaOrderEntity['filled_qty'], + filledAveragePrice = alpacaOrderEntity['filled_avg_price'], + orderClass = alpacaOrderEntity['order_class'], + timeInForce = AlpacaOrder.TimeInForce.fromString(alpacaOrderEntity['time_in_force']), + limitPrice = alpacaOrderEntity['limit_price'], + stopPrice = alpacaOrderEntity['stop_price'], + extendedHours = alpacaOrderEntity['extended_hours'], + legs = alpacaOrderEntity['legs'], + trailPercent = alpacaOrderEntity['trail_percent'], + trailPrice = alpacaOrderEntity['trail_price'], + hwm = alpacaOrderEntity['hwm'], + takeProfit = alpacaOrderEntity['take_profit']['limit_price'], + stopLossStop = alpacaOrderEntity['stop_loss']['stop_price'], + stopLossLimit = alpacaOrderEntity['stop_loss']['limit_price'] + ) + return order + +def toAlpacaOrder(order): + """ order should be a AlpacaOrder object. + + see https://alpaca.markets/docs/api-documentation/api-v2/orders/#order-entity + for details. + """ - # END broker.Broker interface + alpacaOrder = { + 'order_id': order.orderId, + 'symbol': order.instrument, + 'qty': order.quantity, + 'notional': order.notional, + 'side': AlpacaOrder.Action.toString(order.action), + 'type': AlpacaOrder.Type.toString(order.type_), + 'time_in_force': AlpacaOrder.TimeInForce.toString(order.timeInForce), + 'limit_price': order.limitPrice, + 'stop_price': order.stopPrice, + 'trail_price': order.trailPrice, + 'trail_percent': order.trailPercent, + 'extended_hours': order.extendedHours, + 'client_order_id': order.clientOrderId, + 'order_class': AlpacaOrder.OrderClass.toString(order.orderClass), + 'take_profit': {'limit_price': order.takeProfit}, + 'stop_loss':{'stop_price': order.stopLossStop} + } + + # for bracket / OTO orders + if order.takeProfit is not None: + alpacaOrder['take_profit'] = {'limit_price': order.takeProfit} + if order.stopLossStop is not None: + alpacaOrder['stop_loss'] = {'stop_price': order.stopLossStop} + if order.stopLossLimit is not None: + alpacaOrder['stop_loss']['limit_price'] = order.stopLossLimit + + # omit items that are None + alpacaOrder = {k: v for k, v in alpacaOrder.items() if v is not None} + + return alpacaOrder + + +class AlpacaOrderEvent(broker.OrderEvent): + """Adds Alpaca specific order states to broker.OrderEvent. + """ + class Type(AlpacaOrder.State): + # use order states + pass \ No newline at end of file diff --git a/pyalgotrade/alpaca/livefeed.py b/pyalgotrade/alpaca/livefeed.py index f874f06b6..deeb88975 100644 --- a/pyalgotrade/alpaca/livefeed.py +++ b/pyalgotrade/alpaca/livefeed.py @@ -15,253 +15,224 @@ # limitations under the License. """ -.. moduleauthor:: Gabriel Martin Becedillas Ruiz +.. moduleauthor:: Robert Lee + Splits out LiveFeed to allow for both live bar feeds and live trade feeds. + """ +import abc import datetime import time +import queue +import threading -from six.moves import queue - -from pyalgotrade import bar -from pyalgotrade import barfeed -from pyalgotrade import observer -from pyalgotrade.bitstamp import common -from pyalgotrade.bitstamp import wsclient - - -class TradeBar(bar.Bar): - # Optimization to reduce memory footprint. - __slots__ = ('__dateTime', '__tradeId', '__price', '__amount') - - def __init__(self, dateTime, trade): - self.__dateTime = dateTime - self.__tradeId = trade.getId() - self.__price = trade.getPrice() - self.__amount = trade.getAmount() - self.__buy = trade.isBuy() - - def __setstate__(self, state): - (self.__dateTime, self.__tradeId, self.__price, self.__amount) = state - - def __getstate__(self): - return (self.__dateTime, self.__tradeId, self.__price, self.__amount) - - def setUseAdjustedValue(self, useAdjusted): - if useAdjusted: - raise Exception("Adjusted close is not available") - - def getTradeId(self): - return self.__tradeId - - def getFrequency(self): - return bar.Frequency.TRADE - - def getDateTime(self): - return self.__dateTime - - def getOpen(self, adjusted=False): - return self.__price - - def getHigh(self, adjusted=False): - return self.__price - - def getLow(self, adjusted=False): - return self.__price - - def getClose(self, adjusted=False): - return self.__price - - def getVolume(self): - return self.__amount - - def getAdjClose(self): - return None - - def getTypicalPrice(self): - return self.__price - - def getPrice(self): - return self.__price - - def getUseAdjValue(self): - return False - - def isBuy(self): - return self.__buy - - def isSell(self): - return not self.__buy +import zmq +import json +from pyalgotrade import feed +from pyalgotrade import dataseries -class LiveTradeFeed(barfeed.BaseBarFeed): +from pyalgotrade.dataseries import bards, tradeds, quoteds +from pyalgotrade.alpaca import common - """A real-time BarFeed that builds bars from live trades. - :param maxLen: The maximum number of values that the :class:`pyalgotrade.dataseries.bards.BarDataSeries` will hold. - Once a bounded length is full, when new items are added, a corresponding number of items are discarded - from the opposite end. If None then dataseries.DEFAULT_MAX_LEN is used. - :type maxLen: int. - - .. note:: - Note that a Bar will be created for every trade, so open, high, low and close values will all be the same. +class LiveFeed(threading.Thread): + """A thread that takes incoming messages from Alapaca and publishes them on ZMQ sockets. """ - QUEUE_TIMEOUT = 0.01 - - def __init__(self, maxLen=None): - super(LiveTradeFeed, self).__init__(bar.Frequency.TRADE, maxLen) - self.__barDicts = [] - self.registerInstrument(common.btc_symbol) - self.__prevTradeDateTime = None - self.__thread = None - self.__wsClientConnected = False - self.__enableReconnection = True + def __init__(self, publishing_address, api_key_id = None, api_secret_key = None): + + # Create sockets + self.__zmq_context = zmq.Context.Instance() + + # for publishing data from the websocket + # (from Alpaca's stream.TradingStream and stream.DataStream) + self._publishing_address = publishing_address + self.__socket = self.__zmq_context.socket(zmq.PUB) + self.__socket.bind(self._publishing_address) + common.logger.info( + 'Live feed publishing data at {self._publishing_address}' + ) + + # threading stuff + self.__stop = False self.__stopped = False - self.__orderBookUpdateEvent = observer.Event() - - # Factory method for testing purposes. - def buildWebSocketClientThread(self): - return wsclient.WebSocketClientThread() - - def getCurrentDateTime(self): - return wsclient.get_current_datetime() - - def enableReconection(self, enableReconnection): - self.__enableReconnection = enableReconnection - def __initializeClient(self): - common.logger.info("Initializing websocket client.") - assert self.__wsClientConnected is False, "Websocket client already connected" - - try: - # Start the thread that runs the client. - self.__thread = self.buildWebSocketClientThread() - self.__thread.start() - except Exception as e: - common.logger.exception("Error connecting : %s" % str(e)) - - # Wait for initialization to complete. - while not self.__wsClientConnected and self.__thread.is_alive(): - self.__dispatchImpl([wsclient.WebSocketClient.Event.CONNECTED]) - - if self.__wsClientConnected: - common.logger.info("Initialization ok.") - else: - common.logger.error("Initialization failed.") - return self.__wsClientConnected - - def __onConnected(self): - self.__wsClientConnected = True - - def __onDisconnected(self): - self.__wsClientConnected = False - - if self.__enableReconnection: - initialized = False - while not self.__stopped and not initialized: - common.logger.info("Reconnecting") - initialized = self.__initializeClient() - if not initialized: - time.sleep(5) - else: - self.__stopped = True - - def __dispatchImpl(self, eventFilter): - ret = False + # make connection + stream = common.make_connection( + connection_type = 'stream', api_key_id = api_key_id, api_secret_key = api_secret_key) + + # Threading stuff + def start(self): + pass + + def run(self): + pass + + def stop(self): + self.__stop = True + + # Functions to handle incoming messages + def publish(self, topic, messages): + for message in messages: + self.__socket.send_multipart([topic.encode(), message]) + + def publish_with_topic(self, topic): + return lambda message: self.publish(topic.encode(), message) + + # subscribe to real time data + def subscribe_trade_updates(self): + self.stream.subscribe_trade_updates(self.publish_with_topic('BROKER')) + + def subscribe_trades(self, *symbols, handler_cancel_errors = None, handler_corrections = None): + self.stream.subscribe_trades(self.publish_with_topic('TRADES'), symbols, handler_cancel_errors, handler_corrections) + + def subscribe_quotes(self, *symbols): + self.stream.subscribe_quotes(self.publish_with_topic('QUOTES'), symbols) + + def subscribe_bars(self, *symbols): + self.stream.subscribe_bars(self.publish_with_topic('BARS'), symbols) + + def subscribe_dailiy_bars(self, *symbols): + self.stream.subscribe_daily_bars(self.publish_with_topic('BARS'), symbols) + + def subscribe_statuses(self, *symbols): + self.stream.subscribe_statuses(self.publish_with_topic('STATUSES'), symbols) + + def subscribe_lulds(self, *symbols): + self.stream.subscribe_lulds(self.publish_with_topic('LULDS'), symbols) + + def subscribe_crypto_trades(self, *symbols): + self.stream.subscribe_crypto_trades(self.publish_with_topic('TRADES'), symbols) + + def subscribe_crypto_quotes(self, *symbols): + self.stream.subscribe_crypto_quotes(self.publish_with_topic('QUOTES'), symbols) + + def subscribe_crypto_bars(self, *symbols): + self.stream.subscribe_crypto_bars(self.publish_with_topic('BARS'), symbols) + + def subscribe_crypto_daily_bars(self, *symbols): + self.stream.subscribe_crypto_daily_bars(self.publish_with_topic('BARS'), symbols) + +class EventQueuer(threading.Thread): + """A thread that checks a ZMQ SUB socket for streaming data. + """ + POLL_FREQUENCY = 0.5 + + def __init__(self, liveFeedAddress, topic): + super(EventQueuer, self).__init__() + + self.__zmq_context = zmq.Context.Instance() + self.__event_socket = self.__zmq_context.socket(zmq.SUB) + self.__event_socket.setsockopt(zmq.SUBSCRIBE, str(topic).encode()) + self.__event_socket.connect(liveFeedAddress) + self.__queue = queue.Queue() + self.__stop = False + + def _getNewEvent(self): try: - eventType, eventData = self.__thread.getQueue().get(True, LiveTradeFeed.QUEUE_TIMEOUT) - if eventFilter is not None and eventType not in eventFilter: - return False - - ret = True - if eventType == wsclient.WebSocketClient.Event.TRADE: - self.__onTrade(eventData) - elif eventType == wsclient.WebSocketClient.Event.ORDER_BOOK_UPDATE: - self.__orderBookUpdateEvent.emit(eventData) - elif eventType == wsclient.WebSocketClient.Event.CONNECTED: - self.__onConnected() - elif eventType == wsclient.WebSocketClient.Event.DISCONNECTED: - self.__onDisconnected() + update = self.__data_socket.recv_multipart(zmq.NOBLOCK) + update = update['data'] + return update + except zmq.ZMQERROR as exc: + if exc.errno == zmq.EAGAIN: + # nothing to get + return else: - ret = False - common.logger.error("Invalid event received to dispatch: %s - %s" % (eventType, eventData)) - except queue.Empty: - pass - return ret + raise - # Bar datetimes should not duplicate. In case trade object datetimes conflict, we just move one slightly forward. - def __getTradeDateTime(self, trade): - ret = trade.getDateTime() - if ret == self.__prevTradeDateTime: - ret += datetime.timedelta(microseconds=1) - self.__prevTradeDateTime = ret - return ret + def getQueue(self): + return self.__queue + + def start(self): + if (newEvent:= self._getNewEvent()): + self.__queue.put(newEvent) + common.logger.info('New Event: {newEvent}') + super(EventQueuer, self).start() + + def run(self): + while not self.__stop: + try: + if (newEvent:= self._getNewTrade()): + self.__queue.put(newEvent) + common.logger.info('New Event: {newEvent}') + else: + time.sleep(EventQueuer.POLL_FREQUENCY) + except Exception as e: + common.logger.critical("Error retrieving new events", exc_info=e) - def __onTrade(self, trade): - # Build a bar for each trade. - barDict = { - common.btc_symbol: TradeBar(self.__getTradeDateTime(trade), trade) - } - self.__barDicts.append(barDict) + def stop(self): + self.__stop = True - def barsHaveAdjClose(self): - return False +class BaseLiveDataFeed(feed.BaseFeed): + + QUEUE_TIMEOUT = 0.01 - def getNextBars(self): + def __init__(self, liveFeedAddress, topic, maxLen = None): + super(BaseLiveDataFeed, self).__init__(maxLen) + + # Queue to get data from + self.__dataQueuer = EventQueuer(liveFeedAddress, topic) + + # keep track of most recent data + self.__currentData = None + self.__lastData = None + + # BEGIN feed.BaseFeed interface + def reset(self): + self.__currentData = None + self.__lastData = {} + super(BaseLiveDataFeed, self).reset() + + @abc.abstrctmethod + def createDataSeries(self, key, maxLen): + pass + + def getNextValues(self): + # from barfeed.BaseBarFeed.getNextValues + dateTime = None + data = self.getNextData() + if data is not None: + dateTime = data.getDateTime + + self.__currentData = data + for instrument in data.getInstruments(): + self.__lastData[instrument] = data[instrument] + + return (dateTime, data) + + # END feed.BaseFeed interface + + def getNextData(self): ret = None - if len(self.__barDicts): - ret = bar.Bars(self.__barDicts.pop(0)) + try: + ret = self.__dataQueuer.getQueue(block = True, timeout = BaseLiveDataFeed.QUEUE_TIMEOUT) + return ret + except: + return False + +class LiveBarFeed(BaseLiveDataFeed): + def __init__(self, liveFeedAddress, maxLen = None): + super(LiveBarFeed, self).__init__(liveFeedAddress, 'BARS') + + def createDataSeries(self, key, maxLen): + ret = bards.BarDataSeries(maxLen) + # real time objects do not use adjusted values + ret.setUseAdjustedValues(False) return ret - def peekDateTime(self): - # Return None since this is a realtime subject. - return None - - # This may raise. - def start(self): - super(LiveTradeFeed, self).start() - if self.__thread is not None: - raise Exception("Already running") - elif not self.__initializeClient(): - self.__stopped = True - raise Exception("Initialization failed") - - def dispatch(self): - # Note that we may return True even if we didn't dispatch any Bar - # event. - ret = False - if self.__dispatchImpl(None): - ret = True - if super(LiveTradeFeed, self).dispatch(): - ret = True +class LiveTradeFeed(BaseLiveDataFeed): + def __init__(self, liveFeedAddress, maxLen = None): + super(LiveBarFeed, self).__init__(liveFeedAddress, 'TRADES') + + def createDataSeries(self, key, maxLen): + ret = tradeds.TradeDataSeries(maxLen) return ret - # This should not raise. - def stop(self): - try: - self.__stopped = True - if self.__thread is not None and self.__thread.is_alive(): - common.logger.info("Shutting down websocket client.") - self.__thread.stop() - except Exception as e: - common.logger.error("Error shutting down client: %s" % (str(e))) - - # This should not raise. - def join(self): - if self.__thread is not None: - self.__thread.join() - - def eof(self): - return self.__stopped - - def getOrderBookUpdateEvent(self): - """ - Returns the event that will be emitted when the orderbook gets updated. - - Eventh handlers should receive one parameter: - 1. A :class:`pyalgotrade.bitstamp.wsclient.OrderBookUpdate` instance. - - :rtype: :class:`pyalgotrade.observer.Event`. - """ - return self.__orderBookUpdateEvent +class LiveQuoteFeed(BaseLiveDataFeed): + def __init__(self, liveFeedAddress, maxLen = None): + super(LiveBarFeed, self).__init__(liveFeedAddress, 'QUOTES') + + def createDataSeries(self, key, maxLen): + ret = quoteds.QuoteDataSeries(maxLen) + return ret diff --git a/pyalgotrade/alpaca/wsclient.py b/pyalgotrade/alpaca/wsclient.py index 0e7a4c67a..1951d0a0e 100644 --- a/pyalgotrade/alpaca/wsclient.py +++ b/pyalgotrade/alpaca/wsclient.py @@ -34,7 +34,10 @@ def get_current_datetime(): class Trade(pusher.Event): - """A trade event.""" + """A trade event. + + # TODO: use broker specific formats + """ def __init__(self, dateTime, eventDict): super(Trade, self).__init__(eventDict, True) @@ -42,7 +45,9 @@ def __init__(self, dateTime, eventDict): def getDateTime(self): """Returns the :class:`datetime.datetime` when this event was received.""" - return self.__dateTime + # return self.__dateTime + # TODO: use event timestamp + raise NotImplementedError() def getId(self): """Returns the trade id.""" @@ -66,7 +71,10 @@ def isSell(self): class OrderBookUpdate(pusher.Event): - """An order book update event.""" + """An order book update event. + + # TODO: use broker specific formats + """ def __init__(self, dateTime, eventDict): super(OrderBookUpdate, self).__init__(eventDict, True) @@ -92,6 +100,39 @@ def getAskVolumes(self): """Returns a list with the top 20 ask volumes.""" return [float(ask[1]) for ask in self.getData()["asks"]] +class Bar(pusher.Event): + """A bar event. + + # TODO: use broker specific formats + """ + + def __init__(self, dateTime, eventDict): + super(OrderBookUpdate, self).__init__(eventDict, True) + self.__dateTime = dateTime + + def getDateTime(self): + """Returns the :class:`datetime.datetime` when this event was received.""" + # return self.__dateTime + raise NotImplementedError() + + def getOpen(self): + raise NotImplementedError() + + def getClose(self): + raise NotImplementedError() + + def getHigh(self): + raise NotImplementedError() + + def getLow(self): + raise NotImplementedError() + + def getVolume(self): + raise NotImplementedError() + + def getFrequency(self): + raise NotImplementedError() + class WebSocketClient(pusher.WebSocketClient): """ @@ -106,9 +147,10 @@ class Event: ORDER_BOOK_UPDATE = 2 CONNECTED = 3 DISCONNECTED = 4 + BAR = 5 def __init__(self, queue): - super(WebSocketClient, self).__init__(WebSocketClient.PUSHER_APP_KEY, 5) + super(WebSocketClient, self).__init__(WebSocketClient.PUSHER_APP_KEY, protocol = 5) self.__queue = queue def onMessage(self, msg): @@ -155,7 +197,7 @@ def onUnknownEvent(self, event): common.logger.warning("Unknown event: %s" % (event)) ###################################################################### - # Bitstamp specific + # Broker specific def onTrade(self, trade): self.__queue.put((WebSocketClient.Event.TRADE, trade)) diff --git a/pyalgotrade/bar.py b/pyalgotrade/bar.py index cd9e13cd9..92c6b82b1 100644 --- a/pyalgotrade/bar.py +++ b/pyalgotrade/bar.py @@ -38,6 +38,7 @@ class Frequency(object): # It is important for frequency values to get bigger for bigger windows. TRADE = -1 + QUOTE = -1 SECOND = 1 MINUTE = 60 HOUR = 60*60 @@ -241,7 +242,6 @@ def getPrice(self): def getExtraColumns(self): return self.__extra - class Bars(object): """A group of :class:`Bar` objects. diff --git a/pyalgotrade/dataseries/quoteds.py b/pyalgotrade/dataseries/quoteds.py new file mode 100644 index 000000000..64147cb86 --- /dev/null +++ b/pyalgotrade/dataseries/quoteds.py @@ -0,0 +1,95 @@ +""" +.. moduleauthor:: Robert Lee +""" + +from pyalgotrade import dataseries + +import six + + +class QuoteDataSeries(dataseries.SequenceDataSeries): + """A DataSeries of :class:`pyalgotrade.bar.Quote` instances. + + :param maxLen: The maximum number of values to hold. + Once a bounded length is full, when new items are added, a corresponding number of items are discarded from the + opposite end. If None then dataseries.DEFAULT_MAX_LEN is used. + :type maxLen: int. + """ + + def __init__(self, maxLen=None): + super(QuoteDataSeries, self).__init__(maxLen) + self.__askExchangeDS = dataseries.SequenceDataSeries(maxLen) + self.__askPriceDS = dataseries.SequenceDataSeries(maxLen) + self.__askSizeDS = dataseries.SequenceDataSeries(maxLen) + self.__bidExchangeDS = dataseries.SequenceDataSeries(maxLen) + self.__bidPriceDS = dataseries.SequenceDataSeries(maxLen) + self.__bidSizeDS = dataseries.SequenceDataSeries(maxLen) + self.__quoteConditionDS = dataseries.SequenceDataSeries(maxLen) + self.__tapeDS = dataseries.SequenceDataSeries(maxLen) + self.__extraDS = {} + + def __getOrCreateExtraDS(self, name): + ret = self.__extraDS.get(name) + if ret is None: + ret = dataseries.SequenceDataSeries(self.getMaxLen()) + self.__extraDS[name] = ret + return ret + + def append(self, trade): + self.appendWithDateTime(trade.getDateTime(), trade) + + def appendWithDateTime(self, dateTime, quote): + assert(dateTime is not None) + assert(quote is not None) + + super(QuoteDataSeries, self).appendWithDateTime(dateTime, quote) + + self.__askExchangeDS.appendWithDateTime(dateTime, quote.getAskExchange()) + self.__askPriceDS.appendWithDateTime(dateTime, quote.getAskPrice()) + self.__askSizeDS.appendWithDateTime(dateTime, quote.getAskSize()) + self.__bidExchangeDS.appendWithDateTime(dateTime, quote.getBidExchange()) + self.__bidPriceDS.appendWithDateTime(dateTime, quote.getBidPrice()) + self.__bidSizeDS.appendWithDateTime(dateTime, quote.getBidSize()) + self.__quoteConditionDS.appendWithDateTime(dateTime, quote.getQuoteCondition()) + self.__tapeDS.appendWithDateTime(dateTime, quote.getTape()) + + # Process extra columns. + for name, value in six.iteritems(quote.getExtraColumns()): + extraDS = self.__getOrCreateExtraDS(name) + extraDS.appendWithDateTime(dateTime, value) + + def getAskExchangeDataSeries(self): + """Returns a :class:`pyalgotrade.dataseries.DataSeries` with the ask exchanges.""" + return self.__askExchangeDS + + def getAskPriceDataSeries(self): + """Returns a :class:`pyalgotrade.dataseries.DataSeries` with the ask prices.""" + return self.__askPriceDS + + def getAskSizeDataSeries(self): + """Returns a :class:`pyalgotrade.dataseries.DataSeries` with the ask sizes.""" + return self.__askSizeDS + + def getBidExchangeDataSeries(self): + """Returns a :class:`pyalgotrade.dataseries.DataSeries` with the bid exchanges.""" + return self.__bidExchangeDS + + def getBidPriceDataSeries(self): + """Returns a :class:`pyalgotrade.dataseries.DataSeries` with the bid prices.""" + return self.__bidPriceDS + + def getBidSizeDataSeries(self): + """Returns a :class:`pyalgotrade.dataseries.DataSeries` with the bid sizes.""" + return self.__bidSizeDS + + def getQuoteConditionDataSeries(self): + """Returns a :class:`pyalgotrade.dataseries.DataSeries` with the quote conditions.""" + return self.__quoteConditionDS + + def getTapeDataSeries(self): + """Returns a :class:`pyalgotrade.dataseries.DataSeries` with the quote tapes.""" + return self.__tapeDS + + def getExtraDataSeries(self, name): + """Returns a :class:`pyalgotrade.dataseries.DataSeries` for an extra column.""" + return self.__getOrCreateExtraDS(name) \ No newline at end of file diff --git a/pyalgotrade/dataseries/tradeds.py b/pyalgotrade/dataseries/tradeds.py new file mode 100644 index 000000000..097185a36 --- /dev/null +++ b/pyalgotrade/dataseries/tradeds.py @@ -0,0 +1,71 @@ +""" +.. moduleauthor:: Robert Lee +""" + +from pyalgotrade import dataseries + +import six + + +class TradeDataSeries(dataseries.SequenceDataSeries): + """A DataSeries of :class:`pyalgotrade.bar.Trade` instances. + + :param maxLen: The maximum number of values to hold. + Once a bounded length is full, when new items are added, a corresponding number of items are discarded from the + opposite end. If None then dataseries.DEFAULT_MAX_LEN is used. + :type maxLen: int. + """ + + def __init__(self, maxLen=None): + super(TradeDataSeries, self).__init__(maxLen) + self.__tradeIdDS = dataseries.SequenceDataSeries(maxLen) + self.__priceDS = dataseries.SequenceDataSeries(maxLen) + self.__sizeDS = dataseries.SequenceDataSeries(maxLen) + self.__isBuyDS = dataseries.SequenceDataSeries(maxLen) + self.__extraDS = {} + + def __getOrCreateExtraDS(self, name): + ret = self.__extraDS.get(name) + if ret is None: + ret = dataseries.SequenceDataSeries(self.getMaxLen()) + self.__extraDS[name] = ret + return ret + + def append(self, trade): + self.appendWithDateTime(trade.getDateTime(), trade) + + def appendWithDateTime(self, dateTime, trade): + assert(dateTime is not None) + assert(trade is not None) + + super(TradeDataSeries, self).appendWithDateTime(dateTime, trade) + + self.__tradeIdDS.appendWithDateTime(dateTime, trade.getTradeId()) + self.__priceDS.appendWithDateTime(dateTime, trade.getPrice()) + self.__sizeDS.appendWithDateTime(dateTime, trade.getSize()) + self.__isBuyDS.appendWithDateTime(dateTime, trade.getIsBuy()) + + # Process extra columns. + for name, value in six.iteritems(trade.getExtraColumns()): + extraDS = self.__getOrCreateExtraDS(name) + extraDS.appendWithDateTime(dateTime, value) + + def getTradeIdDataSeries(self): + """Returns a :class:`pyalgotrade.dataseries.DataSeries` with the trade Ids.""" + return self.__tradeIdDs + + def getPriceDataSeries(self): + """Returns a :class:`pyalgotrade.dataseries.DataSeries` with the trade prices.""" + return self.__priceDS + + def getSizeDataSeries(self): + """Returns a :class:`pyalgotrade.dataseries.DataSeries` with the trade sizes.""" + return self.__sizeDS + + def getIsBuyDataSeries(self): + """Returns a :class:`pyalgotrade.dataseries.DataSeries` with whether the trades are buys.""" + return self.__isBuyDS + + def getExtraDataSeries(self, name): + """Returns a :class:`pyalgotrade.dataseries.DataSeries` for an extra column.""" + return self.__getOrCreateExtraDS(name) diff --git a/pyalgotrade/quote.py b/pyalgotrade/quote.py new file mode 100644 index 000000000..97a302ed0 --- /dev/null +++ b/pyalgotrade/quote.py @@ -0,0 +1,158 @@ +""" +.. moduleauthor:: Robert Lee +""" + +from pyalgotrade import bar + +Frequency = bar.Frequency + +class Quote(object): + # Optimization to reduce memory footprint. + __slots__ = ( + '__dateTime', + '__ask_exchange', + '__ask_price', + '__ask_size', + '__bid_exchange', + '__bid_price', + '__bid_size', + '__quote_condition', + '__tape', + '__extra' + ) + + def __init__(self, dateTime, + askExchange, askPrice, askSize, + bidExchange, bidPrice, bidSize, + quoteCondition, tape, extra = {}): + + self.__dateTime = dateTime + self.__askExchange = askExchange + self.__askPrice = askPrice + self.__askSize = askSize + self.__bidExchange = bidExchange + self.__bidPrice = bidPrice + self.__bidSize = bidSize + self.__quoteCondition = quoteCondition + self.__tape = tape + self.__extra = extra + + def __setstate__(self, state): + (self.__dateTime, + self.__askExchange, + self.__askPrice, + self.__askSize, + self.__bidExchange, + self.__bidPrice, + self.__bidSize, + self.__quoteCondition, + self.__tape, + self.__extra + ) = state + + def __getstate__(self): + return ( + self.__dateTime, + self.__askExchange, + self.__askPrice, + self.__askSize, + self.__bidExchange, + self.__bidPrice, + self.__bidSize, + self.__quoteCondition, + self.__tape, + self.__extra + ) + + def getFrequency(self): + return Frequency.QUOTE + + def getDateTime(self): + return self.__dateTime + + def getAskExchange(self): + return self.__askExchange + + def getAskPrice(self): + return self.__askPrice + + def getAskSize(self): + return self.__askSize + + def getBidExchange(self): + return self.__bidExchange + + def getBidPrice(self): + return self.__bidPrice + + def getBidSize(self): + return self.__bidSize + + def getQuoteCondition(self): + return self.__quoteCondition + + def getTape(self): + return self.__tape + + def getExtraColumns(self): + return self.__extra + +class Quotes(object): + + """A group of :class:`Quote` objects. + + :param quoteDict: A map of instrument to :class:`Quote` objects. + :type quoteDict: map. + + .. note:: + All bars must have the same datetime. + """ + + def __init__(self, quoteDict): + if len(quoteDict) == 0: + raise Exception("No quotes supplied") + + # Check that bar datetimes are in sync + firstDateTime = None + firstInstrument = None + for instrument, currentQuote in six.iteritems(quoteDict): + if firstDateTime is None: + firstDateTime = currentQuote.getDateTime() + firstInstrument = instrument + elif currentQuote.getDateTime() != firstDateTime: + raise Exception("Quote data times are not in sync. %s %s != %s %s" % ( + instrument, + currentQuote.getDateTime(), + firstInstrument, + firstDateTime + )) + + self.__quoteDict = quoteDict + self.__dateTime = firstDateTime + + def __getitem__(self, instrument): + """Returns the :class:`pyalgoquote.bar.Quote` for the given instrument. + If the instrument is not found an exception is raised.""" + return self.__quoteDict[instrument] + + def __contains__(self, instrument): + """Returns True if a :class:`pyalgoquote.bar.Quote` for the given instrument is available.""" + return instrument in self.__quoteDict + + def items(self): + return list(self.__quoteDict.items()) + + def keys(self): + return list(self.__quoteDict.keys()) + + def getInstruments(self): + """Returns the instrument symbols.""" + return list(self.__quoteDict.keys()) + + def getDateTime(self): + """Returns the :class:`datetime.datetime` for this set of quotes.""" + return self.__dateTime + + def getQuote(self, instrument): + """Returns the :class:`pyalgoquote.bar.Quote` for the given instrument or None if the instrument is not found.""" + return self.__quoteDict.get(instrument, None) \ No newline at end of file diff --git a/pyalgotrade/trade.py b/pyalgotrade/trade.py new file mode 100644 index 000000000..e4f9f138b --- /dev/null +++ b/pyalgotrade/trade.py @@ -0,0 +1,107 @@ +""" +.. moduleauthor:: Robert Lee +""" + +from pyalgotrade import bar + +Frequency = bar.Frequency + +class Trade(object): + + # Optimization to reduce memory footprint. + __slots__ = ('__dateTime', '__tradeId', '__price', '__size', '__isBuy') + + def __init__(self, dateTime, tradeId, price, size, isBuy, extra = {}): + self.__dateTime = dateTime + self.__tradeId = tradeId + self.__price = price + self.__size = size + self.__isBuy = isBuy + self.__extra = extra + + def __setstate__(self, state): + (self.__dateTime, self.__tradeId, self.__price, self.__amount, self.__extra) = state + + def __getstate__(self): + return (self.__dateTime, self.__tradeId, self.__price, self.__amount, self.__extra) + + def getFrequency(self): + return Frequency.TRADE + + def getDateTime(self): + return self.__dateTime + + def getTradeId(self): + return self.__tradeId + + def getPrice(self): + return self.__price + + def getSize(self): + return self.__size + + def getIsBuy(self): + return self.__isBuy + + def getExtraColumns(self): + return self.__extra + +class Trades(object): + + """A group of :class:`Trade` objects. + + :param tradeDict: A map of instrument to :class:`Trade` objects. + :type tradeDict: map. + + .. note:: + All bars must have the same datetime. + """ + + def __init__(self, tradeDict): + if len(tradeDict) == 0: + raise Exception("No trades supplied") + + # Check that bar datetimes are in sync + firstDateTime = None + firstInstrument = None + for instrument, currentTrade in six.iteritems(tradeDict): + if firstDateTime is None: + firstDateTime = currentTrade.getDateTime() + firstInstrument = instrument + elif currentTrade.getDateTime() != firstDateTime: + raise Exception("Trade data times are not in sync. %s %s != %s %s" % ( + instrument, + currentTrade.getDateTime(), + firstInstrument, + firstDateTime + )) + + self.__tradeDict = tradeDict + self.__dateTime = firstDateTime + + def __getitem__(self, instrument): + """Returns the :class:`pyalgotrade.bar.Trade` for the given instrument. + If the instrument is not found an exception is raised.""" + return self.__tradeDict[instrument] + + def __contains__(self, instrument): + """Returns True if a :class:`pyalgotrade.bar.Trade` for the given instrument is available.""" + return instrument in self.__tradeDict + + def items(self): + return list(self.__tradeDict.items()) + + def keys(self): + return list(self.__tradeDict.keys()) + + def getInstruments(self): + """Returns the instrument symbols.""" + return list(self.__tradeDict.keys()) + + def getDateTime(self): + """Returns the :class:`datetime.datetime` for this set of trades.""" + return self.__dateTime + + def getTrade(self, instrument): + """Returns the :class:`pyalgotrade.bar.Trade` for the given instrument or None if the instrument is not found.""" + return self.__tradeDict.get(instrument, None) From 310edfc2c6104db1167d7ede7801bc8751597406 Mon Sep 17 00:00:00 2001 From: jibi Date: Sat, 22 Jan 2022 14:19:54 -0800 Subject: [PATCH 7/7] More alpaca work --- pyalgotrade/alpaca/broker.py | 125 ---------------- pyalgotrade/alpaca/common.py | 33 +++-- pyalgotrade/alpaca/httpclient.py | 237 ------------------------------ pyalgotrade/alpaca/livefeed.py | 136 +++++++++++++----- pyalgotrade/alpaca/wsclient.py | 240 ------------------------------- pyalgotrade/quote.py | 15 +- pyalgotrade/trade.py | 60 ++++++-- 7 files changed, 186 insertions(+), 660 deletions(-) delete mode 100644 pyalgotrade/alpaca/broker.py delete mode 100644 pyalgotrade/alpaca/httpclient.py delete mode 100644 pyalgotrade/alpaca/wsclient.py diff --git a/pyalgotrade/alpaca/broker.py b/pyalgotrade/alpaca/broker.py deleted file mode 100644 index c4d93d809..000000000 --- a/pyalgotrade/alpaca/broker.py +++ /dev/null @@ -1,125 +0,0 @@ -# PyAlgoTrade -# -# Copyright 2011-2018 Gabriel Martin Becedillas Ruiz -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -.. moduleauthor:: Robert Lee - -Don't think we need a custom backtester for alpaca. -""" - - -from pyalgotrade import broker -from pyalgotrade.broker import backtesting -from pyalgotrade.bitstamp import common -from pyalgotrade.bitstamp import livebroker - - -LiveBroker = livebroker.LiveBroker - -# In a backtesting or paper-trading scenario the BacktestingBroker dispatches events while processing events from the -# BarFeed. -# It is guaranteed to process BarFeed events before the strategy because it connects to BarFeed events before the -# strategy. - - -# class BacktestingBroker(backtesting.Broker): -# MIN_TRADE_USD = 5 - -# """An Alpaca backtesting broker. - -# :param cash: The initial amount of cash. -# :type cash: int/float. -# :param barFeed: The bar feed that will provide the bars. -# :type barFeed: :class:`pyalgotrade.barfeed.BarFeed` -# :param fee: The fee percentage for each order. Defaults to 0.25%. -# :type fee: float. - -# .. note:: -# * Only limit orders are supported. -# * Orders are automatically set as **goodTillCanceled=True** and **allOrNone=False**. -# * BUY_TO_COVER orders are mapped to BUY orders. -# * SELL_SHORT orders are mapped to SELL orders. -# """ - -# def __init__(self, cash, barFeed, fee=0.0025): -# commission = backtesting.TradePercentage(fee) -# super(BacktestingBroker, self).__init__(cash, barFeed, commission) - -# # def getInstrumentTraits(self, instrument): -# # return common.BTCTraits() - -# def submitOrder(self, order): -# # if order.isInitial(): -# # # Override user settings based on Bitstamp limitations. -# # order.setAllOrNone(False) -# # order.setGoodTillCanceled(True) -# return super(BacktestingBroker, self).submitOrder(order) - -# # def createMarketOrder(self, action, instrument, quantity, onClose=False): -# # raise Exception("Market orders are not supported") - -# def createLimitOrder(self, action, instrument, limitPrice, quantity): -# if instrument != common.btc_symbol: -# raise Exception("Only BTC instrument is supported") - -# if action == broker.Order.Action.BUY_TO_COVER: -# action = broker.Order.Action.BUY -# elif action == broker.Order.Action.SELL_SHORT: -# action = broker.Order.Action.SELL - -# if limitPrice * quantity < BacktestingBroker.MIN_TRADE_USD: -# raise Exception("Trade must be >= %s" % (BacktestingBroker.MIN_TRADE_USD)) - -# if action == broker.Order.Action.BUY: -# # Check that there is enough cash. -# fee = self.getCommission().calculate(None, limitPrice, quantity) -# cashRequired = limitPrice * quantity + fee -# if cashRequired > self.getCash(False): -# raise Exception("Not enough cash") -# elif action == broker.Order.Action.SELL: -# # Check that there are enough coins. -# if quantity > self.getShares(common.btc_symbol): -# raise Exception("Not enough %s" % (common.btc_symbol)) -# else: -# raise Exception("Only BUY/SELL orders are supported") - -# return super(BacktestingBroker, self).createLimitOrder(action, instrument, limitPrice, quantity) - -# def createStopOrder(self, action, instrument, stopPrice, quantity): -# raise Exception("Stop orders are not supported") - -# def createStopLimitOrder(self, action, instrument, stopPrice, limitPrice, quantity): -# raise Exception("Stop limit orders are not supported") - - -# class PaperTradingBroker(BacktestingBroker): -# """A Bitstamp paper trading broker. - -# :param cash: The initial amount of cash. -# :type cash: int/float. -# :param barFeed: The bar feed that will provide the bars. -# :type barFeed: :class:`pyalgotrade.barfeed.BarFeed` -# :param fee: The fee percentage for each order. Defaults to 0.5%. -# :type fee: float. - -# .. note:: -# * Only limit orders are supported. -# * Orders are automatically set as **goodTillCanceled=True** and **allOrNone=False**. -# * BUY_TO_COVER orders are mapped to BUY orders. -# * SELL_SHORT orders are mapped to SELL orders. -# """ - -# pass diff --git a/pyalgotrade/alpaca/common.py b/pyalgotrade/alpaca/common.py index 2aba4b16a..fcef1cd99 100644 --- a/pyalgotrade/alpaca/common.py +++ b/pyalgotrade/alpaca/common.py @@ -18,6 +18,10 @@ .. moduleauthor:: Robert Lee """ import os +from datetime import datetime + +import msgpack +import pandas as pd import alpaca_trade_api as tradeapi from alpaca_trade_api.rest_async import AsyncRest @@ -29,7 +33,7 @@ logger = pyalgotrade.logger.getLogger("alpaca") -def make_connection(connection_type, api_key_id = None, api_secret_key = None): +def make_connection(connection_type, api_key_id = None, api_secret_key = None, live = False): """Makes a connection to Alpaca. https://alpaca.markets/docs/api-documentation/api-v2/ @@ -43,9 +47,13 @@ def make_connection(connection_type, api_key_id = None, api_secret_key = None): """ # credentials - api_key_id = api_key_id or os.environ.get('ALPACA_API_KEY_ID') - api_secret_key = api_secret_key or os.environ.get('ALPACA_API_SECRET_KEY') - + if live: + api_key_id = api_key_id or os.environ.get('ALPACA_API_KEY_ID') + api_secret_key = api_secret_key or os.environ.get('ALPACA_API_SECRET_KEY') + else: + api_key_id = api_key_id or os.environ.get('ALPACA_API_KEY_ID_PAPER') + api_secret_key = api_secret_key or os.environ.get('ALPACA_API_SECRET_KEY_PAPER') + if api_key_id is None: logger.error('Unable to retrieve API Key ID.') if api_key_id is None: @@ -60,9 +68,16 @@ def make_connection(connection_type, api_key_id = None, api_secret_key = None): return connection -# btc_symbol = "BTC" - +def json_serializer(obj): + if isinstance(obj, datetime): + return {'_isoformat': obj.isoformat()} + elif isinstance(obj, msgpack.ext.Timestamp): + return {'_unix_nano': obj.to_unix_nano()} + raise TypeError('...') -# class BTCTraits(broker.InstrumentTraits): -# def roundQuantity(self, quantity): -# return round(quantity, 8) +def json_deserializer(obj): + if (_isoformat := obj.get('_isoformat')) is not None: + return datetime.fromisoformat(_isoformat) + elif (_unix_nano := obj.get('_unix_nano')) is not None: + return pd.to_datetime(_unix_nano) + return obj \ No newline at end of file diff --git a/pyalgotrade/alpaca/httpclient.py b/pyalgotrade/alpaca/httpclient.py deleted file mode 100644 index be49fe73e..000000000 --- a/pyalgotrade/alpaca/httpclient.py +++ /dev/null @@ -1,237 +0,0 @@ -# PyAlgoTrade -# -# Copyright 2011-2018 Gabriel Martin Becedillas Ruiz -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -.. moduleauthor:: Robert Lee -""" - -import time -import datetime -import hmac -import hashlib -import requests -import threading - -from pyalgotrade.utils import dt -from pyalgotrade.alpaca import common - -import logging -logging.getLogger("requests").setLevel(logging.ERROR) - - -def parse_datetime(dateTime): - try: - ret = datetime.datetime.strptime(dateTime, "%Y-%m-%d %H:%M:%S") - except ValueError: - ret = datetime.datetime.strptime(dateTime, "%Y-%m-%d %H:%M:%S.%f") - return dt.as_utc(ret) - - -class NonceGenerator(object): - def __init__(self): - self.__prev = None - - def getNext(self): - ret = int(time.time()) - if self.__prev is not None and ret <= self.__prev: - ret = self.__prev + 1 - self.__prev = ret - return ret - - -class AccountBalance(object): - def __init__(self, jsonDict): - self.__jsonDict = jsonDict - - def getDict(self): - return self.__jsonDict - - def getUSDAvailable(self): - return float(self.__jsonDict["usd_available"]) - - def getBTCAvailable(self): - return float(self.__jsonDict["btc_available"]) - - -class Order(object): - def __init__(self, jsonDict): - self.__jsonDict = jsonDict - - def getDict(self): - return self.__jsonDict - - def getId(self): - return int(self.__jsonDict["id"]) - - def isBuy(self): - return self.__jsonDict["type"] == 0 - - def isSell(self): - return self.__jsonDict["type"] == 1 - - def getPrice(self): - return float(self.__jsonDict["price"]) - - def getAmount(self): - return float(self.__jsonDict["amount"]) - - def getDateTime(self): - return parse_datetime(self.__jsonDict["datetime"]) - - -class UserTransaction(object): - def __init__(self, jsonDict): - self.__jsonDict = jsonDict - - def getDict(self): - return self.__jsonDict - - def getBTC(self): - return float(self.__jsonDict["btc"]) - - def getBTCUSD(self): - return float(self.__jsonDict["btc_usd"]) - - def getDateTime(self): - return parse_datetime(self.__jsonDict["datetime"]) - - def getFee(self): - return float(self.__jsonDict["fee"]) - - def getId(self): - return int(self.__jsonDict["id"]) - - def getOrderId(self): - return int(self.__jsonDict["order_id"]) - - def getUSD(self): - return float(self.__jsonDict["usd"]) - - -class HTTPClient(object): - USER_AGENT = "PyAlgoTrade" - REQUEST_TIMEOUT = 30 - - class UserTransactionType: - MARKET_TRADE = 2 - - def __init__(self, clientId, key, secret): - self.__clientId = clientId - self.__key = key - self.__secret = secret - self.__nonce = NonceGenerator() - self.__lock = threading.Lock() - - def _buildQuery(self, params): - # Build the signature. - nonce = self.__nonce.getNext() - message = "%d%s%s" % (nonce, self.__clientId, self.__key) - signature = hmac.new(self.__secret, msg=message, digestmod=hashlib.sha256).hexdigest().upper() - - # Headers - headers = {} - headers["User-Agent"] = HTTPClient.USER_AGENT - - # POST data. - data = {} - data.update(params) - data["key"] = self.__key - data["signature"] = signature - data["nonce"] = nonce - - return (data, headers) - - def _post(self, url, params): - common.logger.debug("POST to %s with params %s" % (url, str(params))) - - # Serialize access to nonce generation and http requests to avoid - # sending them in the wrong order. - with self.__lock: - data, headers = self._buildQuery(params) - response = requests.post(url, headers=headers, data=data, timeout=HTTPClient.REQUEST_TIMEOUT) - response.raise_for_status() - - jsonResponse = response.json() - - # Check for errors. - if isinstance(jsonResponse, dict): - error = jsonResponse.get("error") - if error is not None: - raise Exception(error) - - return jsonResponse - - def getAccountBalance(self): - url = "https://www.bitstamp.net/api/balance/" - jsonResponse = self._post(url, {}) - return AccountBalance(jsonResponse) - - def getOpenOrders(self): - url = "https://www.bitstamp.net/api/open_orders/" - jsonResponse = self._post(url, {}) - return [Order(json_open_order) for json_open_order in jsonResponse] - - def cancelOrder(self, orderId): - url = "https://www.bitstamp.net/api/cancel_order/" - params = {"id": orderId} - jsonResponse = self._post(url, params) - if jsonResponse != True: - raise Exception("Failed to cancel order") - - def buyLimit(self, limitPrice, quantity): - url = "https://www.bitstamp.net/api/buy/" - - # Rounding price to avoid 'Ensure that there are no more than 2 decimal places' - # error. - price = round(limitPrice, 2) - # Rounding amount to avoid 'Ensure that there are no more than 8 decimal places' - # error. - amount = round(quantity, 8) - - params = { - "price": price, - "amount": amount - } - jsonResponse = self._post(url, params) - return Order(jsonResponse) - - def sellLimit(self, limitPrice, quantity): - url = "https://www.bitstamp.net/api/sell/" - - # Rounding price to avoid 'Ensure that there are no more than 2 decimal places' - # error. - price = round(limitPrice, 2) - # Rounding amount to avoid 'Ensure that there are no more than 8 decimal places' - # error. - amount = round(quantity, 8) - - params = { - "price": price, - "amount": amount - } - jsonResponse = self._post(url, params) - return Order(jsonResponse) - - def getUserTransactions(self, transactionType=None): - url = "https://www.bitstamp.net/api/user_transactions/" - jsonResponse = self._post(url, {}) - if transactionType is not None: - jsonUserTransactions = filter( - lambda jsonUserTransaction: jsonUserTransaction["type"] == transactionType, jsonResponse - ) - else: - jsonUserTransactions = jsonResponse - return [UserTransaction(jsonUserTransaction) for jsonUserTransaction in jsonUserTransactions] diff --git a/pyalgotrade/alpaca/livefeed.py b/pyalgotrade/alpaca/livefeed.py index deeb88975..f12107e94 100644 --- a/pyalgotrade/alpaca/livefeed.py +++ b/pyalgotrade/alpaca/livefeed.py @@ -25,12 +25,16 @@ import time import queue import threading +import asyncio +from alpaca_trade_api.entity import Quote import zmq import json +import pandas as pd from pyalgotrade import feed from pyalgotrade import dataseries +from pyalgotrade import bar, quote, trade from pyalgotrade.dataseries import bards, tradeds, quoteds from pyalgotrade.alpaca import common @@ -40,10 +44,10 @@ class LiveFeed(threading.Thread): """A thread that takes incoming messages from Alapaca and publishes them on ZMQ sockets. """ - def __init__(self, publishing_address, api_key_id = None, api_secret_key = None): + def __init__(self, publishing_address = 'tcp://*:34567', api_key_id = None, api_secret_key = None): # Create sockets - self.__zmq_context = zmq.Context.Instance() + self.__zmq_context = zmq.Context() # for publishing data from the websocket # (from Alpaca's stream.TradingStream and stream.DataStream) @@ -51,7 +55,7 @@ def __init__(self, publishing_address, api_key_id = None, api_secret_key = None) self.__socket = self.__zmq_context.socket(zmq.PUB) self.__socket.bind(self._publishing_address) common.logger.info( - 'Live feed publishing data at {self._publishing_address}' + f'Live feed publishing data at {self._publishing_address}' ) # threading stuff @@ -59,70 +63,87 @@ def __init__(self, publishing_address, api_key_id = None, api_secret_key = None) self.__stopped = False # make connection - stream = common.make_connection( - connection_type = 'stream', api_key_id = api_key_id, api_secret_key = api_secret_key) + # use the non-paper connection to get data + self.stream = common.make_connection( + connection_type = 'stream', api_key_id = api_key_id, api_secret_key = api_secret_key, live = True) + - # Threading stuff def start(self): - pass - - def run(self): - pass + self.stream.run() def stop(self): - self.__stop = True - + self.__zmq_context.term() + + + # for dynamic subscription + # See https://github.com/alpacahq/alpaca-trade-api-python/blob/master/examples/websockets/dynamic_subscription_example.py + # def consumer_thread(): + # try: + # # make sure we have an event loop, if not create a new one + # loop = asyncio.get_event_loop() + # loop.set_debug(True) + # except RuntimeError: + # asyncio.set_event_loop(asyncio.new_event_loop()) + + # Functions to handle incoming messages - def publish(self, topic, messages): + async def publish(self, topic, messages): for message in messages: self.__socket.send_multipart([topic.encode(), message]) def publish_with_topic(self, topic): - return lambda message: self.publish(topic.encode(), message) + async def publish(message): + msg = json.dumps(message, default = common.json_serializer) + self.__socket.send_multipart([topic.encode(), msg.encode()]) + return publish + # async def print_stuff(message): + # print(pd.to_datetime(message['t'].to_unix_nano())) + # return print_stuff # subscribe to real time data def subscribe_trade_updates(self): self.stream.subscribe_trade_updates(self.publish_with_topic('BROKER')) def subscribe_trades(self, *symbols, handler_cancel_errors = None, handler_corrections = None): - self.stream.subscribe_trades(self.publish_with_topic('TRADES'), symbols, handler_cancel_errors, handler_corrections) + self.stream.subscribe_trades(self.publish_with_topic('TRADES'), *symbols, handler_cancel_errors, handler_corrections) def subscribe_quotes(self, *symbols): - self.stream.subscribe_quotes(self.publish_with_topic('QUOTES'), symbols) + self.stream.subscribe_quotes(self.publish_with_topic('QUOTES'), *symbols) def subscribe_bars(self, *symbols): - self.stream.subscribe_bars(self.publish_with_topic('BARS'), symbols) + self.stream.subscribe_bars(self.publish_with_topic('BARS'), *symbols) def subscribe_dailiy_bars(self, *symbols): - self.stream.subscribe_daily_bars(self.publish_with_topic('BARS'), symbols) + self.stream.subscribe_daily_bars(self.publish_with_topic('BARS'), *symbols) def subscribe_statuses(self, *symbols): - self.stream.subscribe_statuses(self.publish_with_topic('STATUSES'), symbols) + self.stream.subscribe_statuses(self.publish_with_topic('STATUSES'), *symbols) def subscribe_lulds(self, *symbols): - self.stream.subscribe_lulds(self.publish_with_topic('LULDS'), symbols) + self.stream.subscribe_lulds(self.publish_with_topic('LULDS'), *symbols) def subscribe_crypto_trades(self, *symbols): - self.stream.subscribe_crypto_trades(self.publish_with_topic('TRADES'), symbols) + # self.stream.subscribe_crypto_trades(self.publish_with_topic('TRADES'), symbols) + self.stream.subscribe_crypto_trades(self.publish_with_topic('TRADES'), *symbols) def subscribe_crypto_quotes(self, *symbols): - self.stream.subscribe_crypto_quotes(self.publish_with_topic('QUOTES'), symbols) + self.stream.subscribe_crypto_quotes(self.publish_with_topic('QUOTES'), *symbols) def subscribe_crypto_bars(self, *symbols): - self.stream.subscribe_crypto_bars(self.publish_with_topic('BARS'), symbols) + self.stream.subscribe_crypto_bars(self.publish_with_topic('BARS'), *symbols) def subscribe_crypto_daily_bars(self, *symbols): - self.stream.subscribe_crypto_daily_bars(self.publish_with_topic('BARS'), symbols) + self.stream.subscribe_crypto_daily_bars(self.publish_with_topic('BARS'), *symbols) class EventQueuer(threading.Thread): """A thread that checks a ZMQ SUB socket for streaming data. """ POLL_FREQUENCY = 0.5 - def __init__(self, liveFeedAddress, topic): + def __init__(self, liveFeedAddress = 'tcp://localhost:34567', topic = ''): super(EventQueuer, self).__init__() - self.__zmq_context = zmq.Context.Instance() + self.__zmq_context = zmq.Context() self.__event_socket = self.__zmq_context.socket(zmq.SUB) self.__event_socket.setsockopt(zmq.SUBSCRIBE, str(topic).encode()) self.__event_socket.connect(liveFeedAddress) @@ -131,10 +152,11 @@ def __init__(self, liveFeedAddress, topic): def _getNewEvent(self): try: - update = self.__data_socket.recv_multipart(zmq.NOBLOCK) - update = update['data'] + topic, update = self.__event_socket.recv_multipart(zmq.NOBLOCK) + update = update.decode() + update = json.loads(update, object_hook = common.json_deserializer) return update - except zmq.ZMQERROR as exc: + except zmq.ZMQError as exc: if exc.errno == zmq.EAGAIN: # nothing to get return @@ -147,15 +169,15 @@ def getQueue(self): def start(self): if (newEvent:= self._getNewEvent()): self.__queue.put(newEvent) - common.logger.info('New Event: {newEvent}') + # common.logger.info(f'New Event: {newEvent}') super(EventQueuer, self).start() def run(self): while not self.__stop: try: - if (newEvent:= self._getNewTrade()): + if (newEvent:= self._getNewEvent()): self.__queue.put(newEvent) - common.logger.info('New Event: {newEvent}') + # common.logger.info(f'New Event: {newEvent}') else: time.sleep(EventQueuer.POLL_FREQUENCY) except Exception as e: @@ -184,7 +206,7 @@ def reset(self): self.__lastData = {} super(BaseLiveDataFeed, self).reset() - @abc.abstrctmethod + @abc.abstractmethod def createDataSeries(self, key, maxLen): pass @@ -206,7 +228,7 @@ def getNextValues(self): def getNextData(self): ret = None try: - ret = self.__dataQueuer.getQueue(block = True, timeout = BaseLiveDataFeed.QUEUE_TIMEOUT) + ret = self.__dataQueuer.getQueue().get(block = True, timeout = BaseLiveDataFeed.QUEUE_TIMEOUT) return ret except: return False @@ -236,3 +258,49 @@ def __init__(self, liveFeedAddress, maxLen = None): def createDataSeries(self, key, maxLen): ret = quoteds.QuoteDataSeries(maxLen) return ret + +def fromAlpacaData(alpacaDataPoint): + # see https://alpaca.markets/docs/api-documentation/api-v2/market-data/alpaca-data-api-v2/real-time/ + # bars + if alpacaDataPoint['T'] == 'b': + data = bar.Bar( + dateTime = alpacaDataPoint.get('t'), + open_ = alpacaDataPoint.get('o'), + high = alpacaDataPoint.get('h'), + low = alpacaDataPoint.get('l'), + close = alpacaDataPoint.get('c'), + volume = alpacaDataPoint.get('v'), + adjClose = alpacaDataPoint.get('c'), + frequency = None, + extra = {'symbol': alpacaDataPoint.get('S')} + ) + # trades + elif alpacaDataPoint.get('T') == 't': + data = trade.Trade( + dateTime = alpacaDataPoint.get('t'), + tradeId = alpacaDataPoint.get('i'), + price = alpacaDataPoint.get('p'), + size = alpacaDataPoint.get('s'), + exchange = alpacaDataPoint.get('x'), + condition = alpacaDataPoint.get('c'), + tape = alpacaDataPoint.get('z'), + takerSide = alpacaDataPoint.get('tks'), + extra = {'symbol': alpacaDataPoint.get('S')} + ) + # quotes + elif alpacaDataPoint.get('T') == 'q': + data = quote.Quote( + dateTime = alpacaDataPoint.get('t'), + askExchange = alpacaDataPoint.get('ax'), + askPrice = alpacaDataPoint.get('ap'), + askSize = alpacaDataPoint.get('as'), + bidExchange = alpacaDataPoint.get('bx'), + bidPrice = alpacaDataPoint.get('bp'), + bidSize = alpacaDataPoint.get('bs'), + condition = alpacaDataPoint.get('c'), + tape = alpacaDataPoint.get('z'), + extra = {'symbol': alpacaDataPoint.get('S')} + ) + + return data + diff --git a/pyalgotrade/alpaca/wsclient.py b/pyalgotrade/alpaca/wsclient.py deleted file mode 100644 index 1951d0a0e..000000000 --- a/pyalgotrade/alpaca/wsclient.py +++ /dev/null @@ -1,240 +0,0 @@ -# PyAlgoTrade -# -# Copyright 2011-2018 Gabriel Martin Becedillas Ruiz -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -.. moduleauthor:: Gabriel Martin Becedillas Ruiz -""" - -import datetime - -from six.moves import queue - -from pyalgotrade.websocket import pusher -from pyalgotrade.websocket import client -from pyalgotrade.bitstamp import common - - -def get_current_datetime(): - return datetime.datetime.now() - -# Bitstamp protocol reference: https://www.bitstamp.net/websocket/ - - -class Trade(pusher.Event): - """A trade event. - - # TODO: use broker specific formats - """ - - def __init__(self, dateTime, eventDict): - super(Trade, self).__init__(eventDict, True) - self.__dateTime = dateTime - - def getDateTime(self): - """Returns the :class:`datetime.datetime` when this event was received.""" - # return self.__dateTime - # TODO: use event timestamp - raise NotImplementedError() - - def getId(self): - """Returns the trade id.""" - return self.getData()["id"] - - def getPrice(self): - """Returns the trade price.""" - return self.getData()["price"] - - def getAmount(self): - """Returns the trade amount.""" - return self.getData()["amount"] - - def isBuy(self): - """Returns True if the trade was a buy.""" - return self.getData()["type"] == 0 - - def isSell(self): - """Returns True if the trade was a sell.""" - return self.getData()["type"] == 1 - - -class OrderBookUpdate(pusher.Event): - """An order book update event. - - # TODO: use broker specific formats - """ - - def __init__(self, dateTime, eventDict): - super(OrderBookUpdate, self).__init__(eventDict, True) - self.__dateTime = dateTime - - def getDateTime(self): - """Returns the :class:`datetime.datetime` when this event was received.""" - return self.__dateTime - - def getBidPrices(self): - """Returns a list with the top 20 bid prices.""" - return [float(bid[0]) for bid in self.getData()["bids"]] - - def getBidVolumes(self): - """Returns a list with the top 20 bid volumes.""" - return [float(bid[1]) for bid in self.getData()["bids"]] - - def getAskPrices(self): - """Returns a list with the top 20 ask prices.""" - return [float(ask[0]) for ask in self.getData()["asks"]] - - def getAskVolumes(self): - """Returns a list with the top 20 ask volumes.""" - return [float(ask[1]) for ask in self.getData()["asks"]] - -class Bar(pusher.Event): - """A bar event. - - # TODO: use broker specific formats - """ - - def __init__(self, dateTime, eventDict): - super(OrderBookUpdate, self).__init__(eventDict, True) - self.__dateTime = dateTime - - def getDateTime(self): - """Returns the :class:`datetime.datetime` when this event was received.""" - # return self.__dateTime - raise NotImplementedError() - - def getOpen(self): - raise NotImplementedError() - - def getClose(self): - raise NotImplementedError() - - def getHigh(self): - raise NotImplementedError() - - def getLow(self): - raise NotImplementedError() - - def getVolume(self): - raise NotImplementedError() - - def getFrequency(self): - raise NotImplementedError() - - -class WebSocketClient(pusher.WebSocketClient): - """ - This websocket client class is designed to be running in a separate thread and for that reason - events are pushed into a queue. - """ - - PUSHER_APP_KEY = "de504dc5763aeef9ff52" - - class Event: - TRADE = 1 - ORDER_BOOK_UPDATE = 2 - CONNECTED = 3 - DISCONNECTED = 4 - BAR = 5 - - def __init__(self, queue): - super(WebSocketClient, self).__init__(WebSocketClient.PUSHER_APP_KEY, protocol = 5) - self.__queue = queue - - def onMessage(self, msg): - # If we can't handle the message, forward it to Pusher WebSocketClient. - event = msg.get("event") - if event == "trade": - self.onTrade(Trade(get_current_datetime(), msg)) - elif event == "data" and msg.get("channel") == "order_book": - self.onOrderBookUpdate(OrderBookUpdate(get_current_datetime(), msg)) - else: - super(WebSocketClient, self).onMessage(msg) - - ###################################################################### - # WebSocketClientBase events. - - def onClosed(self, code, reason): - common.logger.info("Closed. Code: %s. Reason: %s." % (code, reason)) - self.__queue.put((WebSocketClient.Event.DISCONNECTED, None)) - - def onDisconnectionDetected(self): - common.logger.warning("Disconnection detected.") - try: - self.stopClient() - except Exception as e: - common.logger.error("Error stopping websocket client: %s." % (str(e))) - self.__queue.put((WebSocketClient.Event.DISCONNECTED, None)) - - ###################################################################### - # Pusher specific events. - - def onConnectionEstablished(self, event): - common.logger.info("Connection established.") - self.__queue.put((WebSocketClient.Event.CONNECTED, None)) - - channels = ["live_trades", "order_book"] - common.logger.info("Subscribing to channels %s." % channels) - for channel in channels: - self.subscribeChannel(channel) - - def onError(self, event): - common.logger.error("Error: %s" % (event)) - - def onUnknownEvent(self, event): - common.logger.warning("Unknown event: %s" % (event)) - - ###################################################################### - # Broker specific - - def onTrade(self, trade): - self.__queue.put((WebSocketClient.Event.TRADE, trade)) - - def onOrderBookUpdate(self, orderBookUpdate): - self.__queue.put((WebSocketClient.Event.ORDER_BOOK_UPDATE, orderBookUpdate)) - - -class WebSocketClientThread(client.WebSocketClientThreadBase): - """ - This thread class is responsible for running a WebSocketClient. - """ - - def __init__(self): - super(WebSocketClientThread, self).__init__() - self.__queue = queue.Queue() - self.__wsClient = None - - def getQueue(self): - return self.__queue - - def run(self): - super(WebSocketClientThread, self).run() - - # We create the WebSocketClient right in the thread, instead of doing so in the constructor, - # because it has thread affinity. - try: - self.__wsClient = WebSocketClient(self.__queue) - self.__wsClient.connect() - self.__wsClient.startClient() - except Exception: - common.logger.exception("Failed to connect: %s") - - def stop(self): - try: - if self.__wsClient is not None: - common.logger.info("Stopping websocket client.") - self.__wsClient.stopClient() - except Exception as e: - common.logger.error("Error stopping websocket client: %s." % (str(e))) diff --git a/pyalgotrade/quote.py b/pyalgotrade/quote.py index 97a302ed0..fb444364f 100644 --- a/pyalgotrade/quote.py +++ b/pyalgotrade/quote.py @@ -7,6 +7,7 @@ Frequency = bar.Frequency class Quote(object): + # TODO: add __str__ and __repr__ method? # Optimization to reduce memory footprint. __slots__ = ( '__dateTime', @@ -16,7 +17,7 @@ class Quote(object): '__bid_exchange', '__bid_price', '__bid_size', - '__quote_condition', + '__condition', '__tape', '__extra' ) @@ -24,7 +25,7 @@ class Quote(object): def __init__(self, dateTime, askExchange, askPrice, askSize, bidExchange, bidPrice, bidSize, - quoteCondition, tape, extra = {}): + condition, tape, extra = {}): self.__dateTime = dateTime self.__askExchange = askExchange @@ -33,7 +34,7 @@ def __init__(self, dateTime, self.__bidExchange = bidExchange self.__bidPrice = bidPrice self.__bidSize = bidSize - self.__quoteCondition = quoteCondition + self.__condition = condition self.__tape = tape self.__extra = extra @@ -45,7 +46,7 @@ def __setstate__(self, state): self.__bidExchange, self.__bidPrice, self.__bidSize, - self.__quoteCondition, + self.__condition, self.__tape, self.__extra ) = state @@ -59,7 +60,7 @@ def __getstate__(self): self.__bidExchange, self.__bidPrice, self.__bidSize, - self.__quoteCondition, + self.__condition, self.__tape, self.__extra ) @@ -88,8 +89,8 @@ def getBidPrice(self): def getBidSize(self): return self.__bidSize - def getQuoteCondition(self): - return self.__quoteCondition + def getCondition(self): + return self.__condition def getTape(self): return self.__tape diff --git a/pyalgotrade/trade.py b/pyalgotrade/trade.py index e4f9f138b..595181f36 100644 --- a/pyalgotrade/trade.py +++ b/pyalgotrade/trade.py @@ -7,23 +7,58 @@ Frequency = bar.Frequency class Trade(object): - + # TODO: add __str__ and __repr__ method? # Optimization to reduce memory footprint. - __slots__ = ('__dateTime', '__tradeId', '__price', '__size', '__isBuy') + __slots__ = ( + '__dateTime', + '__tradeId', + '__price', + '__size', + '__exchange', + '__condition', + '__tape', + '__takerSide', + '__extra' + ) + + def __init__(self, dateTime, + tradeId, price, size, + exchange, condition, tape, takerSide, + extra = {}): - def __init__(self, dateTime, tradeId, price, size, isBuy, extra = {}): self.__dateTime = dateTime self.__tradeId = tradeId self.__price = price self.__size = size - self.__isBuy = isBuy + self.__exchange = exchange + self.__condition = condition + self.__tape = tape + self.__takerSide = takerSide self.__extra = extra def __setstate__(self, state): - (self.__dateTime, self.__tradeId, self.__price, self.__amount, self.__extra) = state + (self.__dateTime, + self.__tradeId, + self.__price, + self.__size, + self.__exchange, + self.__condition, + self.__tape, + self.__takerSide, + self.__extra) = state def __getstate__(self): - return (self.__dateTime, self.__tradeId, self.__price, self.__amount, self.__extra) + return ( + self.__dateTime, + self.__tradeId, + self.__price, + self.__size, + self.__exchange, + self.__condition, + self.__tape, + self.__takerSide, + self.__extra + ) def getFrequency(self): return Frequency.TRADE @@ -40,8 +75,17 @@ def getPrice(self): def getSize(self): return self.__size - def getIsBuy(self): - return self.__isBuy + def getExchange(self): + return self.__exchange + + def getCondition(self): + return self.__condition + + def getTape(self): + return self.__tape + + def getTakerSide(self): + return self.__takerSide def getExtraColumns(self): return self.__extra