112 lines
2.7 KiB
Python
112 lines
2.7 KiB
Python
from mongoengine import *
|
|
import config as cfg
|
|
from bson import Decimal128
|
|
from decimal import Decimal
|
|
import yfi
|
|
|
|
# Register mongoengine's "default" connection when this module is imported,
# so the Document classes defined below can be queried and saved.
connect(db=cfg.mongodb_database, alias="default", host=cfg.mongodb_uri)
|
|
|
|
|
|
class Portfolio(Document):
    """Per-user stock holdings persisted through mongoengine."""

    # Owner identifier; documents are looked up with objects.get(user=...).
    user = StringField(required=True)
    # Encoded ticker -> quantity held (stored as bson Decimal128 by this
    # module). A callable default gives each document its own fresh dict
    # rather than sharing a single dict literal across instances.
    stocks = DictField(default=dict)
|
|
|
|
|
|
class Watchlist(Document):
    """Per-user watched tickers persisted through mongoengine."""

    # Owner identifier; documents are looked up with objects.get(user=...).
    user = StringField(required=True)
    # Encoded ticker -> estimated price (stored as bson Decimal128 by this
    # module). A callable default gives each document its own fresh dict
    # rather than sharing a single dict literal across instances.
    stocks = DictField(default=dict)
|
|
|
|
|
|
def encode_key(key):
    """Return ``key`` with every "." replaced for use as a dict-field key.

    Each dot becomes the six-character text ``\\u002e`` (a literal
    backslash sequence, not a Unicode escape), since mongoengine's DictField
    rejects keys containing ".". ``decode_key`` performs the reverse.

    NOTE(review): a key that already contains the literal text ``\\u002e``
    will not round-trip through encode_key/decode_key unchanged.
    """
    pieces = key.split(".")
    return "\\u002e".join(pieces)
|
|
|
|
|
|
def decode_key(key):
    """Undo ``encode_key``: turn each literal ``\\u002e`` text back into "."."""
    pieces = key.split("\\u002e")
    return ".".join(pieces)
|
|
|
|
|
|
def user_check(user):
    """Make sure ``user`` has both a Portfolio and a Watchlist document.

    Any missing document is created empty and saved; existing ones are
    left untouched. The Portfolio is checked first, then the Watchlist.
    """
    # NOTE(review): check-then-create is not atomic; two concurrent first
    # calls for the same user could insert duplicates, which would make the
    # callers' objects.get(user=...) fail. Confirm whether that can happen.
    for model in (Portfolio, Watchlist):
        if model.objects(user=user):
            continue
        model(user=user).save()
|
|
|
|
|
|
# Portfolio operations
|
|
|
|
|
|
def get_stocks(user):
    """Return ``user``'s holdings as ``{ticker: Decimal quantity}``.

    Empty Portfolio/Watchlist documents are created first for unknown users.
    Stored keys are passed through ``decode_key`` and stored Decimal128
    values are converted with ``to_decimal()``.
    """
    user_check(user)
    holdings = Portfolio.objects.get(user=user).stocks
    return {decode_key(key): qty.to_decimal() for key, qty in holdings.items()}
|
|
|
|
|
|
def add_stock(user, stock, amount):
    """Add ``amount`` of ``stock`` to ``user``'s portfolio.

    If the ticker is already held, ``amount`` is added to the stored
    quantity; otherwise the ticker is inserted with ``amount``.

    Args:
        user: Identifier of the portfolio owner.
        stock: Ticker symbol as the user typed it (e.g. "BRK.B").
        amount: Quantity to add. Anything ``Decimal`` accepts (str, int,
            Decimal); floats are converted through ``str`` first.

    Returns:
        False if ``yfi.stock_exists`` reports the ticker as unknown (nothing
        is saved); otherwise True after the portfolio has been saved.

    Raises:
        decimal.InvalidOperation: if ``amount`` is not a valid number.
    """
    user_check(user)

    # Decimal(float) keeps the float's full binary expansion (Decimal(0.1)
    # has 55 significant digits), which bson's Decimal128 refuses with
    # decimal.Inexact. str() gives the shortest round-tripping repr instead.
    if isinstance(amount, float):
        amount = str(amount)
    amount = Decimal(amount)

    portfolio = Portfolio.objects.get(user=user)

    if not yfi.stock_exists(stock):
        # Stock ticker does not exist on yahoo finance
        return False

    key = encode_key(stock)
    if key in portfolio.stocks:
        total = amount + portfolio.stocks[key].to_decimal()
    else:
        total = amount
    portfolio.stocks[key] = Decimal128(total)
    portfolio.save()
    return True
|
|
|
|
|
|
def delete_stock(user, stock):
    """Remove ``stock`` from ``user``'s portfolio and return what remains.

    Removing a ticker that is not held is a no-op. The return value is the
    same ``{ticker: Decimal quantity}`` mapping ``get_stocks`` produces.
    """
    user_check(user)
    key = encode_key(stock)
    holdings = Portfolio.objects.get(user=user)

    if key in holdings.stocks:
        del holdings.stocks[key]
    holdings.save()

    return get_stocks(user)
|
|
|
|
|
|
# Watchlist operations
|
|
|
|
|
|
def get_watchlist(user):
    """Return ``user``'s watchlist as ``{ticker: Decimal estimated price}``.

    Empty Portfolio/Watchlist documents are created first for unknown users.
    Stored keys are passed through ``decode_key`` and stored Decimal128
    values are converted with ``to_decimal()``.
    """
    user_check(user)
    watched = Watchlist.objects.get(user=user).stocks
    return {decode_key(key): price.to_decimal() for key, price in watched.items()}
|
|
|
|
|
|
def watch(user, stock, est_price="0"):
    """Add ``stock`` to ``user``'s watchlist with an estimated price.

    Watching a ticker that is already on the list overwrites its price.

    Args:
        user: Identifier of the watchlist owner.
        stock: Ticker symbol as the user typed it (e.g. "BRK.B").
        est_price: Estimated price. Anything ``Decimal`` accepts (str, int,
            Decimal); floats are converted through ``str`` first.
            Defaults to "0".

    Returns:
        False if ``yfi.stock_exists`` reports the ticker as unknown (nothing
        is saved); otherwise True after the watchlist has been saved.

    Raises:
        decimal.InvalidOperation: if ``est_price`` is not a valid number.
    """
    user_check(user)

    watchlist = Watchlist.objects.get(user=user)

    if not yfi.stock_exists(stock):
        # Stock ticker does not exist on yahoo finance
        return False

    # Decimal(float) keeps the float's full binary expansion, which bson's
    # Decimal128 refuses with decimal.Inexact; str() gives the shortest
    # round-tripping repr instead.
    if isinstance(est_price, float):
        est_price = str(est_price)

    watchlist.stocks[encode_key(stock)] = Decimal128(Decimal(est_price))
    watchlist.save()
    return True
|
|
|
|
|
|
def unwatch(user, stock):
    """Remove ``stock`` from ``user``'s watchlist.

    Removing a ticker that is not on the list is a no-op.

    Returns:
        Always True.
    """
    user_check(user)

    # Use the shared helper rather than re-spelling the key encoding, so
    # watch/unwatch (and the portfolio functions) cannot drift apart.
    key = encode_key(stock)

    watchlist = Watchlist.objects.get(user=user)

    if key in watchlist.stocks:
        del watchlist.stocks[key]
    watchlist.save()

    return True
|