stock-tracker-discord/mongo.py
2021-03-11 12:09:05 +00:00

112 lines
2.7 KiB
Python

from mongoengine import *
import config as cfg
from bson import Decimal128
from decimal import Decimal
import yfi
# Register the "default" MongoEngine connection as soon as this module is
# imported; the database name and URI come from the project's config module.
connect(
    db=cfg.mongodb_database,
    alias="default",
    host=cfg.mongodb_uri,
)
class Portfolio(Document):
    """MongoEngine document holding the stock positions of a single user."""

    # Identifier of the owning user (presumably a Discord user id -- confirm
    # against the callers).
    user = StringField(required=True)
    # Maps an encoded ticker (see encode_key) to the Decimal128 amount written
    # by add_stock. A callable default is used so that every document starts
    # from its own fresh dict instead of one shared mutable object.
    stocks = DictField(default=dict)
class Watchlist(Document):
    """MongoEngine document holding the watched tickers of a single user."""

    # Identifier of the owning user (presumably a Discord user id -- confirm
    # against the callers).
    user = StringField(required=True)
    # Maps an encoded ticker (see encode_key) to the Decimal128 estimated
    # price written by watch. A callable default is used so that every
    # document starts from its own fresh dict instead of one shared mutable
    # object.
    stocks = DictField(default=dict)
def encode_key(key):
    """Return *key* with every "." replaced by the literal text ``\\u002e``.

    The result is what this module uses as a key inside the ``stocks``
    dict fields; decode_key performs the reverse substitution.
    """
    dot, escaped_dot = ".", "\\u002e"
    return key.replace(dot, escaped_dot)
def decode_key(key):
    """Return *key* with every literal ``\\u002e`` turned back into ".".

    This undoes the substitution applied by encode_key.
    """
    escaped_dot, dot = "\\u002e", "."
    return key.replace(escaped_dot, dot)
def user_check(user):
    """Ensure *user* has both a Portfolio and a Watchlist document.

    For each of the two document classes, an empty document is created and
    saved when the query for this user returns nothing. Existing documents
    are left untouched.
    """
    # Portfolio first, then Watchlist -- same order as the lookups elsewhere
    # in this module.
    for document_cls in (Portfolio, Watchlist):
        if not document_cls.objects(user=user):
            document_cls(user=user).save()
# Portfolio Stuff
def get_stocks(user):
    """Return the portfolio of *user* as a plain dict.

    Keys are the decoded ticker symbols and values are the stored amounts
    converted with ``to_decimal()``. Missing user documents are created
    first, so a new user gets an empty dict.
    """
    user_check(user)
    holdings = Portfolio.objects.get(user=user).stocks
    return {
        decode_key(ticker): quantity.to_decimal()
        for ticker, quantity in holdings.items()
    }
def add_stock(user, stock, amount):
    """Add *amount* of *stock* to the portfolio of *user*.

    *amount* is converted with ``Decimal(amount)`` before anything else is
    looked up, so an unparsable value raises from that conversion (e.g.
    ``decimal.InvalidOperation``). If the ticker is already held, the new
    amount is added to the stored one; otherwise a new entry is created.

    Returns False when ``yfi.stock_exists`` rejects the ticker (nothing is
    saved in that case), True once the portfolio has been saved.
    """
    user_check(user)
    quantity = Decimal(amount)
    record = Portfolio.objects.get(user=user)

    if not yfi.stock_exists(stock):
        # Stock ticker does not exist on yahoo finance
        return False

    key = encode_key(stock)
    holdings = record.stocks
    total = (
        quantity + holdings[key].to_decimal() if key in holdings else quantity
    )
    holdings[key] = Decimal128(total)
    record.save()
    return True
def delete_stock(user, stock):
    """Remove *stock* from the portfolio of *user* and return what is left.

    When the ticker is not held, nothing is saved. In both cases the return
    value is the current portfolio as produced by get_stocks.
    """
    user_check(user)
    key = encode_key(stock)
    record = Portfolio.objects.get(user=user)
    holdings = record.stocks
    if key in holdings:
        del holdings[key]
        record.save()
    return get_stocks(user)
# Watchlist Stuff
def get_watchlist(user):
    """Return the watchlist of *user* as a plain dict.

    Keys are the decoded ticker symbols and values are the stored entries
    converted with ``to_decimal()``. Missing user documents are created
    first, so a new user gets an empty dict.
    """
    user_check(user)
    watched = Watchlist.objects.get(user=user).stocks
    return {
        decode_key(ticker): price.to_decimal()
        for ticker, price in watched.items()
    }
def watch(user, stock, est_price="0"):
    """Put *stock* on the watchlist of *user* with an estimated price.

    An existing entry for the same ticker is overwritten. *est_price* is
    converted with ``Decimal(est_price)`` and stored as Decimal128.

    Returns False when ``yfi.stock_exists`` rejects the ticker (nothing is
    saved in that case), True once the watchlist has been saved.
    """
    user_check(user)
    record = Watchlist.objects.get(user=user)

    if not yfi.stock_exists(stock):
        # Stock ticker does not exist on yahoo finance
        return False

    key = encode_key(stock)
    price = Decimal(est_price)
    record.stocks[key] = Decimal128(price)
    record.save()
    return True
def unwatch(user, stock):
    """Remove *stock* from the watchlist of *user*.

    The ticker is encoded with encode_key, the same helper used when the
    entry was written, so lookup and storage cannot drift apart. When the
    ticker is not on the watchlist, nothing is saved.

    Always returns True, whether or not an entry was removed.
    """
    user_check(user)
    key = encode_key(stock)
    record = Watchlist.objects.get(user=user)
    watched = record.stocks
    if key in watched:
        del watched[key]
        record.save()
    return True