Skip to content

Commit

Permalink
Update test_trading_envronment.py
Browse files Browse the repository at this point in the history
  • Loading branch information
kasinadhsarma authored Aug 12, 2024
1 parent a2150cc commit db4d38c
Showing 1 changed file with 62 additions and 60 deletions.
122 changes: 62 additions & 60 deletions test/test_trading_envronment.py
Original file line number Diff line number Diff line change
@@ -1,76 +1,78 @@
import unittest
from unittest.mock import MagicMock
from unittest.mock import patch, MagicMock
import numpy as np
import jax.numpy as jnp
import tensorflow as tf
import gym
from your_module import TradingEnvironment, moving_average_jax, load_data # Replace 'your_module' with the actual module name

from fetchai.ledger.crypto import Entity
from fetchai.ledger.contract import Contract
from fetchai.ledger.api import LedgerApi
from fetchai.ledger.api.token import TokenTxFactory
from src.trading_environment import TradingEnvironment
from moving_average_crossover_strategy import crossover_strategy_jax, crossover_strategy_tf
from trading_agent import TradingAgent # Assume your code is in trading_agent.py
class TestTradingEnvironment(unittest.TestCase):

class TestTradingAgent(unittest.TestCase):
def setUp(self):
# Mock the Fetch.ai components
self.entity = MagicMock(spec=Entity)
self.ledger_api = MagicMock(spec=LedgerApi)
self.contract = MagicMock(spec=Contract)
self.agent = TradingAgent(self.entity, self.ledger_api, self.contract)
# Mock data
self.mock_prices = np.array([100, 105, 110, 115, 120, 125, 130], dtype=np.float32)

# Mock the load_data function
self.mock_load_data = MagicMock(return_value=self.mock_prices)

# Initialize the environment
self.env = TradingEnvironment(csv_file_path='dummy_path.csv')

# Mock TradingEnvironment
self.agent.environment = MagicMock(spec=TradingEnvironment)
self.agent.environment.prices = np.array([100, 105, 110, 115])
self.agent.environment.current_step = 0
self.agent.environment.balance = 1000
self.agent.environment.shares_held = 0
self.agent.environment.reset.return_value = (self.agent.environment.balance, self.agent.environment.shares_held, 100, 105)
self.agent.environment.step.return_value = (self.agent.environment.balance, 0, True, {})
@patch('your_module.load_data', self.mock_load_data)
def test_initialization(self):
self.assertEqual(self.env.initial_balance, 10000)
self.assertEqual(self.env.transaction_fee, 0.001)
self.assertEqual(self.env.action_space.n, 3)
self.assertEqual(self.env.observation_space.shape, (4,))
self.assertTrue(np.all(self.env.prices == self.mock_prices))

def test_make_decision_buy_signal(self):
# Mock the crossover strategies
crossover_strategy_jax = MagicMock(return_value=jnp.array([1]))
crossover_strategy_tf = MagicMock(return_value=tf.convert_to_tensor([1]))
@patch('your_module.moving_average_jax')
def test_step_buy(self, mock_moving_average):
mock_moving_average.return_value = np.array([100, 105, 110, 115, 120, 125, 130])

obs = self.env.reset()
self.env.step(1) # Buy action

# Test buy decision
action = self.agent.make_decision((1000, 0, 100, 105))
self.assertEqual(action, 1) # Buy action
self.assertEqual(self.env.balance, 10000 - 100 * (1 + self.env.transaction_fee))
self.assertEqual(self.env.shares_held, 10000 // 100)
self.assertEqual(self.env.current_step, 1)

@patch('your_module.moving_average_jax')
def test_step_sell(self, mock_moving_average):
mock_moving_average.return_value = np.array([100, 105, 110, 115, 120, 125, 130])

def test_make_decision_sell_signal(self):
# Mock the crossover strategies
crossover_strategy_jax = MagicMock(return_value=jnp.array([-1]))
crossover_strategy_tf = MagicMock(return_value=tf.convert_to_tensor([-1]))
# Simulate buying first
self.env.step(1) # Buy action
self.env.step(2) # Sell action

# Test sell decision
self.agent.environment.shares_held = 10
action = self.agent.make_decision((1000, 10, 100, 105))
self.assertEqual(action, 2) # Sell action
self.assertEqual(self.env.balance, (10000 // 100) * 100 * (1 - self.env.transaction_fee))
self.assertEqual(self.env.shares_held, 0)
self.assertEqual(self.env.current_step, 2)

def test_calculate_reward(self):
self.env.reset()
self.env.step(1) # Buy action
self.env.step(2) # Sell action
reward = self.env._calculate_reward()

def test_execute_trade_buy(self):
self.agent.environment.balance = 1000
self.agent.environment.prices = np.array([100])
self.agent.execute_trade(1) # Buy action
# Check if transfer was called with the correct parameters
self.ledger_api.sync.assert_called_once()
# Add more specific assertions if needed
expected_reward = (10000 // 100) * 100 * (1 - self.env.transaction_fee) - 10000
self.assertAlmostEqual(reward, expected_reward)

def test_execute_trade_sell(self):
self.agent.environment.shares_held = 10
self.agent.environment.prices = np.array([100])
self.agent.execute_trade(2) # Sell action
# Check if transfer was called with the correct parameters
self.ledger_api.sync.assert_called_once()
# Add more specific assertions if needed
def test_reset(self):
self.env.step(1) # Perform some actions
obs = self.env.reset()

self.assertEqual(self.env.balance, 10000)
self.assertEqual(self.env.shares_held, 0)
self.assertEqual(self.env.current_step, 0)
self.assertEqual(obs[0], 10000)
self.assertEqual(obs[1], 0)

def test_run(self):
# Mock the methods in the TradingEnvironment
self.agent.environment.step.return_value = (1000, 0, False, {}) # Run one step
self.agent.run(num_episodes=1)
self.agent.environment.reset.assert_called()
self.agent.environment.step.assert_called()
# Add more specific assertions if needed
@patch('builtins.print')
def test_render(self, mock_print):
self.env.reset()
self.env.render()
mock_print.assert_called_with(
f'Step: 0\nBalance: 10000\nShares held: 0\nCurrent price: 100\nTotal value: 10000'
)

if __name__ == "__main__":
unittest.main()

0 comments on commit db4d38c

Please sign in to comment.