Skip to content

Commit

Permalink
Merge pull request #2 from VishwamAI/fix-workflow-issues-pr
Browse files Browse the repository at this point in the history
Fix workflow issues pr
  • Loading branch information
kasinadhsarma authored Aug 12, 2024
2 parents a055181 + 9e73c42 commit eb2ce21
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 41 deletions.
45 changes: 20 additions & 25 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ jobs:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3

- name: Set up Python 3.8
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: 3.8

- name: Cache dependencies
uses: actions/cache@v2
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
Expand All @@ -31,38 +31,33 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements.txt
- name: Install flake8
run: pip install flake8

- name: Run linting
run: |
pip install flake8
flake8 .
run: flake8 .

- name: Run tests
run: |
pip install pytest pytest-cov
pytest --cov=your_module_name_here # Replace with the actual module name
pytest --cov=src
- name: Upload coverage results
if: success()
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v3
with:
name: coverage-report
path: coverage.xml

lint:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v2

- name: Set up Python 3.8
uses: actions/setup-python@v2
with:
python-version: 3.8

- name: Install dependencies
- name: Debug information
run: |
python -m pip install --upgrade pip
pip install flake8
- name: Run linting
run: flake8 .
echo "Python version:"
python --version
echo "Pip version:"
pip --version
echo "Installed packages:"
pip list
echo "Current directory:"
pwd
echo "Directory contents:"
ls -R
26 changes: 18 additions & 8 deletions src/agent_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
from fetchai.ledger.api.token import TokenTxFactory

import numpy as np
from trading_environment import TradingEnvironment
from moving_average_crossover_strategy import crossover_strategy
import jax.numpy as jnp
import tensorflow as tf
from src.trading_environment import TradingEnvironment
from moving_average_crossover_strategy import crossover_strategy_jax, crossover_strategy_tf

class TradingAgent:
def __init__(self, entity: Entity, ledger_api: LedgerApi, contract: Contract):
Expand All @@ -20,14 +22,21 @@ def make_decision(self, observation):
# Extract relevant information from the observation
balance, shares_held, short_ma, long_ma = observation

# Use the crossover strategy to generate a signal
prices = np.array([short_ma, long_ma]) # Use the moving averages as a proxy for prices
signal = crossover_strategy(prices, self.short_window, self.long_window)[-1]
# Use both JAX and TensorFlow crossover strategies to generate signals
price = self.environment.prices[self.environment.current_step]
prices_jax = jnp.array([price])
prices_tf = tf.constant([price])

# Convert signal to action
if signal == 1 and balance > 0:
signal_jax = crossover_strategy_jax(prices_jax, self.short_window, self.long_window)[-1]
signal_tf = crossover_strategy_tf(prices_tf, self.short_window, self.long_window)[-1].numpy()

# Combine signals (e.g., take the average)
combined_signal = (signal_jax + signal_tf) / 2

# Convert combined signal to action
if combined_signal > 0.5 and balance > 0:
return 1 # Buy
elif signal == -1 and shares_held > 0:
elif combined_signal < -0.5 and shares_held > 0:
return 2 # Sell
else:
return 0 # Hold
Expand Down Expand Up @@ -80,5 +89,6 @@ def main():
agent = TradingAgent(entity, ledger_api, contract)
agent.run()


if __name__ == "__main__":
main()
11 changes: 6 additions & 5 deletions src/alpha_vantage_data_fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@ def fetch_daily_stock_data(symbol):
'symbol': symbol,
'apikey': API_KEY
}

try:
response = requests.get(BASE_URL, params=params)
response.raise_for_status() # Raise an exception for bad status codes
data = response.json()

if 'Error Message' in data:
logging.error(f"Error fetching data for {symbol}: {data['Error Message']}")
return None

return data
except requests.RequestException as e:
logging.error(f"Request failed for {symbol}: {str(e)}")
Expand All @@ -51,7 +51,7 @@ def save_data_to_file(data, symbol):

def main():
symbols = ['AAPL', 'GOOGL', 'MSFT'] # Example stock symbols

for symbol in symbols:
logging.info(f"Fetching data for {symbol}")
data = fetch_daily_stock_data(symbol)
Expand All @@ -60,8 +60,9 @@ def main():
else:
logging.warning(f"No data fetched for {symbol}")


if __name__ == "__main__":
if not API_KEY:
logging.error("Alpha Vantage API key not found. Please set the ALPHA_VANTAGE_API_KEY environment variable.")
else:
main()
main()
1 change: 1 addition & 0 deletions src/moving_average_crossover_stratagy.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,6 @@ def main():

print(f"Final cumulative return: {cumulative_returns[-1]:.2f}")


if __name__ == "__main__":
main()
9 changes: 6 additions & 3 deletions src/trading_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
import pandas as pd
from gym import spaces
from moving_average_crossover_strategy import moving_average, load_data
from moving_average_crossover_strategy import moving_average_jax, load_data

class TradingEnvironment(gym.Env):
def __init__(self, csv_file_path, initial_balance=10000, transaction_fee=0.001):
Expand All @@ -29,11 +29,14 @@ def __init__(self, csv_file_path, initial_balance=10000, transaction_fee=0.001):
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float32)

def _get_observation(self):
short_ma = moving_average_jax(self.prices[:self.current_step + 1], self.short_window)[-1]
long_ma = moving_average_jax(self.prices[:self.current_step + 1], self.long_window)[-1]

obs = np.array([
self.balance,
self.shares_held,
moving_average(self.prices[:self.current_step+1], self.short_window)[-1],
moving_average(self.prices[:self.current_step+1], self.long_window)[-1]
short_ma if not np.isnan(short_ma) else self.prices[self.current_step],
long_ma if not np.isnan(long_ma) else self.prices[self.current_step]
])
return obs

Expand Down
1 change: 1 addition & 0 deletions src/yahoo_finance_data_fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,6 @@ def main():
else:
logging.warning(f"No data fetched for {symbol}")


if __name__ == "__main__":
main()

0 comments on commit eb2ce21

Please sign in to comment.