from mongoengine import *
import config as cfg
from bson import Decimal128
from decimal import Decimal
import yfi

connect(cfg.mongodb_database, "default", host=cfg.mongodb_uri)


class Portfolio(Document):
    """A user's holdings: encoded ticker -> Decimal128 amount."""

    user = StringField(required=True)
    # A callable default gives every new document its own dict instead of
    # sharing a single mutable dict object between instances.
    stocks = DictField(default=dict)


class Watchlist(Document):
    """A user's watched tickers: encoded ticker -> Decimal128 est_price value."""

    user = StringField(required=True)
    stocks = DictField(default=dict)


def encode_key(key):
    """Escape "." in *key* so it can be used as a MongoDB field name.

    MongoDB interprets "." in a field name as a path separator, so a ticker
    such as "BRK.B" is stored with each dot replaced by the literal text
    "\\u002e".
    """
    return key.replace(".", "\\u002e")


def decode_key(key):
    """Reverse ``encode_key``: turn each literal "\\u002e" back into "."."""
    return key.replace("\\u002e", ".")


def _to_decimal(value):
    """Convert *value* to ``Decimal`` in a form ``Decimal128`` accepts.

    Floats are converted via ``str`` so that e.g. ``0.1`` becomes
    ``Decimal("0.1")`` rather than its exact binary expansion, which has more
    digits than ``Decimal128`` can hold and is rejected as inexact. Every other
    type is passed to ``Decimal`` unchanged.
    """
    if isinstance(value, float):
        return Decimal(str(value))
    return Decimal(value)


def user_check(user):
    """Create an empty Portfolio and/or Watchlist for *user* if missing."""
    # NOTE(review): check-then-create is not atomic. Two concurrent first
    # requests for the same user could both create a document, after which
    # ``objects.get(user=...)`` would raise MultipleObjectsReturned.
    if not Portfolio.objects(user=user):
        Portfolio(user=user).save()
    if not Watchlist.objects(user=user):
        Watchlist(user=user).save()


# Portfolio Stuff


def get_stocks(user):
    """Return *user*'s portfolio as ``{ticker: Decimal}`` with decoded tickers."""
    user_check(user)
    stocks = Portfolio.objects.get(user=user).stocks
    return {
        decode_key(stock): amount.to_decimal()
        for stock, amount in stocks.items()
    }


def add_stock(user, stock, amount):
    """Add *amount* of *stock* to *user*'s portfolio.

    If the ticker is already present, *amount* is added to the stored value;
    otherwise a new entry is created.

    Args:
        user: Owner of the portfolio.
        stock: Ticker symbol (un-encoded).
        amount: Anything ``Decimal`` accepts (str, int, Decimal) or a float.

    Returns:
        True if the portfolio was updated; False if ``yfi.stock_exists``
        reports the ticker does not exist (nothing is saved in that case).
    """
    user_check(user)
    # Convert before the existence check so a malformed amount fails fast.
    amount = _to_decimal(amount)
    if not yfi.stock_exists(stock):
        # Stock ticker does not exist on yahoo finance
        return False
    portfolio = Portfolio.objects.get(user=user)
    key = encode_key(stock)
    if key in portfolio.stocks:
        amount += portfolio.stocks[key].to_decimal()
    portfolio.stocks[key] = Decimal128(amount)
    portfolio.save()
    return True


def delete_stock(user, stock):
    """Remove *stock* from *user*'s portfolio if present.

    Returns:
        The updated portfolio, as returned by ``get_stocks``.
    """
    user_check(user)
    key = encode_key(stock)
    portfolio = Portfolio.objects.get(user=user)
    if key in portfolio.stocks:
        del portfolio.stocks[key]
        portfolio.save()
    return get_stocks(user)


# Watchlist Stuff


def get_watchlist(user):
    """Return *user*'s watchlist as ``{ticker: Decimal}`` with decoded tickers."""
    user_check(user)
    wlist = Watchlist.objects.get(user=user).stocks
    return {decode_key(stock): val.to_decimal() for stock, val in wlist.items()}


def watch(user, stock, est_price="0"):
    """Add *stock* to *user*'s watchlist, storing *est_price* as its value.

    Re-watching an already watched ticker overwrites its stored value.

    Returns:
        True if the watchlist was updated; False if ``yfi.stock_exists``
        reports the ticker does not exist (nothing is saved in that case).
    """
    user_check(user)
    if not yfi.stock_exists(stock):
        # Stock ticker does not exist on yahoo finance
        return False
    watchlist = Watchlist.objects.get(user=user)
    watchlist.stocks[encode_key(stock)] = Decimal128(_to_decimal(est_price))
    watchlist.save()
    return True


def unwatch(user, stock):
    """Remove *stock* from *user*'s watchlist if present; always returns True."""
    user_check(user)
    # Use the shared helper so keys always match what ``watch`` stored.
    key = encode_key(stock)
    watchlist = Watchlist.objects.get(user=user)
    if key in watchlist.stocks:
        del watchlist.stocks[key]
        watchlist.save()
    return True