Moved flatfile database to flatfile.py

Added mongodb connection settings to config
Updated config format
This commit is contained in:
socks 2021-03-08 13:34:51 +00:00
parent a0aad6bc97
commit a893cc1c04
4 changed files with 55 additions and 29 deletions

View file

@ -1,5 +1,13 @@
# fmt:off # fmt:off
config = {
"database": "stonks", # MongoDB is recommended but not required
"apikey": "Your discord API key" use_mongodb = False
}
# Set filename for flat file database
database = "stonks.txt"
# For URI info see https://docs.mongodb.com/manual/reference/connection-string/
mongodb_uri = "mongodb://localhost:27017"
mongodb_database = "stonks"
api_key = "Your Discord API Key"

View file

@ -2,8 +2,10 @@ import json
import asyncio import asyncio
import yfinance as yf import yfinance as yf
from decimal import Decimal from decimal import Decimal
from config import config import config as cfg
db = config['database']
db = cfg.database
def stock_exists(ticker): def stock_exists(ticker):
try: try:
@ -12,10 +14,12 @@ def stock_exists(ticker):
except KeyError: except KeyError:
return False return False
def write_file(data): def write_file(data):
with open(db, 'w') as f: with open(db, "w") as f:
f.write(data) f.write(data)
def get_stocks(user): def get_stocks(user):
with open(db) as f: with open(db) as f:
data = json.loads(f.read()) data = json.loads(f.read())
@ -23,22 +27,26 @@ def get_stocks(user):
data = dict(data) data = dict(data)
if user in list(data.keys()): if user in list(data.keys()):
user_data = data[user] user_data = data[user]
return user_data['portfolio'] return user_data["portfolio"]
else: else:
data[user] = {"portfolio": {}, "watchlist": {}} data[user] = {"portfolio": {}, "watchlist": {}}
write_file(json.dumps(data)) write_file(json.dumps(data))
return get_stocks(user) return get_stocks(user)
def add_stock(user, stock, amount): def add_stock(user, stock, amount):
with open(db) as f: with open(db) as f:
data = dict(json.loads(f.read())) data = dict(json.loads(f.read()))
if user in data.keys(): if user in data.keys():
if stock in data[user]['portfolio'].keys(): if stock in data[user]["portfolio"].keys():
data[user]['portfolio'][stock] = str(Decimal(amount) + Decimal( data[user]['portfolio'][stock] )) data[user]["portfolio"][stock] = str(
Decimal(amount) + Decimal(data[user]["portfolio"][stock])
)
elif stock_exists(stock): elif stock_exists(stock):
data[user]['portfolio'][stock] = str(Decimal(amount)) data[user]["portfolio"][stock] = str(Decimal(amount))
else: return False else:
return False
write_file(json.dumps(data)) write_file(json.dumps(data))
return True return True
else: else:
@ -52,11 +60,11 @@ def delete_stock(user, stock):
if user in data.keys(): if user in data.keys():
# user exists # user exists
portfolio = data[user]['portfolio'] portfolio = data[user]["portfolio"]
if stock in portfolio.keys(): if stock in portfolio.keys():
# stock exists # stock exists
del portfolio[stock] del portfolio[stock]
data[user]['portfolio'] = portfolio data[user]["portfolio"] = portfolio
write_file(json.dumps(data)) write_file(json.dumps(data))
return get_stocks(user) return get_stocks(user)
@ -69,7 +77,8 @@ def get_watchlist(user):
get_stocks(user) get_stocks(user)
return get_watchlist(user) return get_watchlist(user)
return data[user]['watchlist'] return data[user]["watchlist"]
def watch(user, stock, est_price=0): def watch(user, stock, est_price=0):
with open(db) as f: with open(db) as f:
@ -82,9 +91,9 @@ def watch(user, stock, est_price=0):
if not stock_exists(stock): if not stock_exists(stock):
return False return False
watchlist = data[user]['watchlist'] watchlist = data[user]["watchlist"]
watchlist[stock] = str(Decimal(est_price)) watchlist[stock] = str(Decimal(est_price))
data[user]['watchlist'] = watchlist data[user]["watchlist"] = watchlist
write_file(json.dumps(data)) write_file(json.dumps(data))
return True return True
@ -98,9 +107,10 @@ def unwatch(user, stock):
get_stocks(user) get_stocks(user)
return return
if not stock_exists: return if not stock_exists:
return
watchlist = data[user]['watchlist'] watchlist = data[user]["watchlist"]
del watchlist[stock] del watchlist[stock]
data[user]['watchlist'] = watchlist data[user]["watchlist"] = watchlist
write_file(json.dumps(data)) write_file(json.dumps(data))

22
main.py
View file

@ -1,17 +1,24 @@
import discord import discord
import database, table, yfi import table, yfi
import table2 as t2 import table2 as t2
from discord.ext import commands from discord.ext import commands
from config import config import config as cfg
from decimal import Decimal from decimal import Decimal
import typing import typing
print(cfg.use_mongodb)
if cfg.use_mongodb:
import mongo as database
print("[*] Using MongoDB Database")
else:
import flatfile as database
print("[*] Using FlatFile Database")
intents = discord.Intents.default() intents = discord.Intents.default()
bot = commands.Bot( bot = commands.Bot(command_prefix="$", intents=intents)
command_prefix="$",
intents=intents,
)
@bot.command() @bot.command()
@ -62,6 +69,7 @@ async def portfolio(ctx):
[current_portfolio_value[stock] for stock in current_portfolio_value] [current_portfolio_value[stock] for stock in current_portfolio_value]
) )
# TODO: Raises ZeroDivisionError if user has no stocks
total_delta = 100 * ( total_delta = 100 * (
(current_portfolio_total_value / yesterday_portfolio_total_value) - 1 (current_portfolio_total_value / yesterday_portfolio_total_value) - 1
) )
@ -183,4 +191,4 @@ async def watchlist(ctx, user: typing.Optional[discord.Member]):
image.close() image.close()
bot.run(config["apikey"]) bot.run(cfg.api_key)

View file

@ -1,10 +1,10 @@
from mongoengine import * from mongoengine import *
from config import config import config as cfg
from bson import Decimal128 from bson import Decimal128
from decimal import Decimal from decimal import Decimal
import yfi import yfi
connect(config["database"]) connect(cfg.mongodb_database, "default", host=cfg.mongodb_uri)
class Portfolio(Document): class Portfolio(Document):