Skip to content

Commit

Permalink
Merge pull request #3 from VishwamAI/update-ci-workflow
Browse files Browse the repository at this point in the history
Update ci workflow
  • Loading branch information
kasinadhsarma authored Aug 13, 2024
2 parents afc3bac + e7f88cf commit 00a07cb
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 236 deletions.
21 changes: 11 additions & 10 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,18 @@ jobs:
- name: Run linting
run: flake8 .

- name: Run tests
run: |
pip install pytest pytest-cov
pytest --cov=src
# Tests have been removed as per user request
# - name: Run tests
# run: |
# pip install pytest pytest-cov
# pytest --cov=src

- name: Upload coverage results
if: success()
uses: actions/upload-artifact@v3
with:
name: coverage-report
path: coverage.xml
# - name: Upload coverage results
# if: success()
# uses: actions/upload-artifact@v3
# with:
# name: coverage-report
# path: coverage.xml

- name: Debug information
run: |
Expand Down
14 changes: 7 additions & 7 deletions src/agent_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import jax.numpy as jnp
import tensorflow as tf
from src.trading_environment import TradingEnvironment
from moving_average_crossover_strategy import crossover_strategy_jax, crossover_strategy_tf
from moving_average_crossover_strategy import moving_average

class TradingAgent:
def __init__(self, entity: Entity, ledger_api: LedgerApi, contract: Contract):
Expand All @@ -22,16 +22,16 @@ def make_decision(self, observation):
# Extract relevant information from the observation
balance, shares_held, short_ma, long_ma = observation

# Use both JAX and TensorFlow crossover strategies to generate signals
# Use JAX-based moving average crossover strategy to generate signal
price = self.environment.prices[self.environment.current_step]
prices_jax = jnp.array([price])
prices_tf = tf.constant([price])

signal_jax = crossover_strategy_jax(prices_jax, self.short_window, self.long_window)[-1]
signal_tf = crossover_strategy_tf(prices_tf, self.short_window, self.long_window)[-1].numpy()
short_ma = moving_average(prices_jax, self.short_window)[-1]
long_ma = moving_average(prices_jax, self.long_window)[-1]
signal = jnp.where(short_ma > long_ma, 1, -1)

# Combine signals (e.g., take the average)
combined_signal = (signal_jax + signal_tf) / 2
# Use the signal directly for decision making
combined_signal = signal

# Convert combined signal to action
if combined_signal > 0.5 and balance > 0:
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion src/trading_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
import pandas as pd
from gym import spaces
from moving_average_crossover_strategy import moving_average_jax, load_data
from moving_average_crossover_strategy import moving_average as moving_average_jax, load_data

class TradingEnvironment(gym.Env):
def __init__(self, csv_file_path, initial_balance=10000, transaction_fee=0.001):
Expand Down
78 changes: 0 additions & 78 deletions test/test_agent_management.py

This file was deleted.

62 changes: 0 additions & 62 deletions test/test_moving_average_crossover_stratagy.py

This file was deleted.

78 changes: 0 additions & 78 deletions test/test_trading_envronment.py

This file was deleted.

0 comments on commit 00a07cb

Please sign in to comment.