diff --git a/config.example.py b/config.example.py index 1a57a03..4bb7de5 100644 --- a/config.example.py +++ b/config.example.py @@ -1,5 +1,13 @@ # fmt:off -config = { - "database": "stonks", - "apikey": "Your discord API key" -} + +# MongoDB is recommended but not required +use_mongodb = False + +# Set filename for flat file database +database = "stonks.txt" + +# For URI info see https://docs.mongodb.com/manual/reference/connection-string/ +mongodb_uri = "mongodb://localhost:27017" +mongodb_database = "stonks" + +api_key = "Your Discord API Key" diff --git a/database.py b/flatfile.py similarity index 72% rename from database.py rename to flatfile.py index 046e463..f5dd726 100644 --- a/database.py +++ b/flatfile.py @@ -2,8 +2,10 @@ import json import asyncio import yfinance as yf from decimal import Decimal -from config import config -db = config['database'] +import config as cfg + +db = cfg.database + def stock_exists(ticker): try: @@ -12,10 +14,12 @@ def stock_exists(ticker): except KeyError: return False + def write_file(data): - with open(db, 'w') as f: + with open(db, "w") as f: f.write(data) + def get_stocks(user): with open(db) as f: data = json.loads(f.read()) @@ -23,22 +27,26 @@ def get_stocks(user): data = dict(data) if user in list(data.keys()): user_data = data[user] - return user_data['portfolio'] + return user_data["portfolio"] else: data[user] = {"portfolio": {}, "watchlist": {}} write_file(json.dumps(data)) return get_stocks(user) + def add_stock(user, stock, amount): with open(db) as f: data = dict(json.loads(f.read())) if user in data.keys(): - if stock in data[user]['portfolio'].keys(): - data[user]['portfolio'][stock] = str(Decimal(amount) + Decimal( data[user]['portfolio'][stock] )) + if stock in data[user]["portfolio"].keys(): + data[user]["portfolio"][stock] = str( + Decimal(amount) + Decimal(data[user]["portfolio"][stock]) + ) elif stock_exists(stock): - data[user]['portfolio'][stock] = str(Decimal(amount)) - else: return False + data[user]["portfolio"][stock] = str(Decimal(amount)) + else: + return False write_file(json.dumps(data)) return True else: @@ -52,11 +60,11 @@ def delete_stock(user, stock): if user in data.keys(): # user exists - portfolio = data[user]['portfolio'] + portfolio = data[user]["portfolio"] if stock in portfolio.keys(): # stock exists del portfolio[stock] - data[user]['portfolio'] = portfolio + data[user]["portfolio"] = portfolio write_file(json.dumps(data)) return get_stocks(user) @@ -69,7 +77,8 @@ def get_watchlist(user): get_stocks(user) return get_watchlist(user) - return data[user]['watchlist'] + return data[user]["watchlist"] + def watch(user, stock, est_price=0): with open(db) as f: @@ -82,9 +91,9 @@ def watch(user, stock, est_price=0): if not stock_exists(stock): return False - watchlist = data[user]['watchlist'] + watchlist = data[user]["watchlist"] watchlist[stock] = str(Decimal(est_price)) - data[user]['watchlist'] = watchlist + data[user]["watchlist"] = watchlist write_file(json.dumps(data)) return True @@ -98,9 +107,10 @@ def unwatch(user, stock): get_stocks(user) return - if not stock_exists: return + if not stock_exists: + return - watchlist = data[user]['watchlist'] + watchlist = data[user]["watchlist"] del watchlist[stock] - data[user]['watchlist'] = watchlist + data[user]["watchlist"] = watchlist write_file(json.dumps(data)) diff --git a/main.py b/main.py index 72d3711..644500c 100644 --- a/main.py +++ b/main.py @@ -1,17 +1,24 @@ import discord -import database, table, yfi +import table, yfi import table2 as t2 from discord.ext import commands -from config import config +import config as cfg from decimal import Decimal import typing +print(cfg.use_mongodb) +if cfg.use_mongodb: + import mongo as database + + print("[*] Using MongoDB Database") +else: + import flatfile as database + + print("[*] Using FlatFile Database") + intents = discord.Intents.default() -bot = commands.Bot( - command_prefix="$", - intents=intents, -) +bot = commands.Bot(command_prefix="$", intents=intents) @bot.command() @@ -62,6 +69,7 @@ async def portfolio(ctx): [current_portfolio_value[stock] for stock in current_portfolio_value] ) + # TODO: Raises ZeroDivisionError if user has no stocks total_delta = 100 * ( (current_portfolio_total_value / yesterday_portfolio_total_value) - 1 ) @@ -183,4 +191,4 @@ async def watchlist(ctx, user: typing.Optional[discord.Member]): image.close() -bot.run(config["apikey"]) +bot.run(cfg.api_key) diff --git a/mongo.py b/mongo.py index d677c8c..cbd623d 100644 --- a/mongo.py +++ b/mongo.py @@ -1,10 +1,10 @@ from mongoengine import * -from config import config +import config as cfg from bson import Decimal128 from decimal import Decimal import yfi -connect(config["database"]) +connect(cfg.mongodb_database, "default", host=cfg.mongodb_uri) class Portfolio(Document):