diff --git a/backend/src/main/java/org/sejongisc/backend/backtest/dto/BacktestRequest.java b/backend/src/main/java/org/sejongisc/backend/backtest/dto/BacktestRequest.java index 68258b9b..b35ab8ad 100644 --- a/backend/src/main/java/org/sejongisc/backend/backtest/dto/BacktestRequest.java +++ b/backend/src/main/java/org/sejongisc/backend/backtest/dto/BacktestRequest.java @@ -35,6 +35,10 @@ public class BacktestRequest { @Schema(description = "백테스트 실행 요청 JSON") private BacktestRunRequest strategy; + @Schema(description = "기본 청산 기간") + private int defaultExitDays; + + // 백테스트 리스트 삭제 @Schema(description = "삭제할 백테스트 실행 리스트") private List backtestRunIds; diff --git a/backend/src/main/java/org/sejongisc/backend/backtest/dto/BacktestResponse.java b/backend/src/main/java/org/sejongisc/backend/backtest/dto/BacktestResponse.java index 6250e9e1..702ba18a 100644 --- a/backend/src/main/java/org/sejongisc/backend/backtest/dto/BacktestResponse.java +++ b/backend/src/main/java/org/sejongisc/backend/backtest/dto/BacktestResponse.java @@ -1,17 +1,12 @@ package org.sejongisc.backend.backtest.dto; -import jakarta.persistence.*; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Getter; import lombok.NoArgsConstructor; -import org.hibernate.annotations.JdbcTypeCode; -import org.hibernate.type.SqlTypes; import org.sejongisc.backend.backtest.entity.BacktestRun; -import org.sejongisc.backend.backtest.entity.BacktestRunMetrics; import org.sejongisc.backend.backtest.entity.BacktestStatus; import org.sejongisc.backend.template.entity.Template; -import org.sejongisc.backend.user.entity.User; import java.time.LocalDate; @@ -21,13 +16,6 @@ @AllArgsConstructor @NoArgsConstructor public class BacktestResponse { - private Long id; - private Template template; - private String title; - private BacktestStatus status; - private String paramsJson; - private LocalDate startDate; - private LocalDate endDate; - - private BacktestRunMetrics backtestRunMetrics; + private BacktestRun backtestRun; + private BacktestRunMetricsResponse backtestRunMetricsResponse; } \ No newline at end of file diff --git a/backend/src/main/java/org/sejongisc/backend/backtest/dto/BacktestRunMetricsResponse.java b/backend/src/main/java/org/sejongisc/backend/backtest/dto/BacktestRunMetricsResponse.java new file mode 100644 index 00000000..5eb3379a --- /dev/null +++ b/backend/src/main/java/org/sejongisc/backend/backtest/dto/BacktestRunMetricsResponse.java @@ -0,0 +1,35 @@ +package org.sejongisc.backend.backtest.dto; + + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.NoArgsConstructor; +import org.sejongisc.backend.backtest.entity.BacktestRunMetrics; + +import java.math.BigDecimal; + + +@Getter +@NoArgsConstructor +@AllArgsConstructor +@Builder +public class BacktestRunMetricsResponse { + private Long id; + private BigDecimal totalReturn; // 총 수익률 + private BigDecimal maxDrawdown; // 최대 낙폭 + private BigDecimal sharpeRatio; // 샤프 지수 + private BigDecimal avgHoldDays; // 평균 보유 기간 + private int tradesCount; // 총 거래 횟수 + + public static BacktestRunMetricsResponse fromEntity(BacktestRunMetrics backtestRunMetrics) { + return BacktestRunMetricsResponse.builder() + .id(backtestRunMetrics.getId()) + .totalReturn(backtestRunMetrics.getTotalReturn()) + .maxDrawdown(backtestRunMetrics.getMaxDrawdown()) + .sharpeRatio(backtestRunMetrics.getSharpeRatio()) + .avgHoldDays(backtestRunMetrics.getAvgHoldDays()) + .tradesCount(backtestRunMetrics.getTradesCount()) + .build(); + } +} diff --git a/backend/src/main/java/org/sejongisc/backend/backtest/dto/StrategyOperand.java b/backend/src/main/java/org/sejongisc/backend/backtest/dto/StrategyOperand.java index 68c31188..a4c597dc 100644 --- a/backend/src/main/java/org/sejongisc/backend/backtest/dto/StrategyOperand.java +++ b/backend/src/main/java/org/sejongisc/backend/backtest/dto/StrategyOperand.java @@ -4,6 +4,7 @@ import lombok.AllArgsConstructor; import lombok.Getter; import lombok.NoArgsConstructor; +import lombok.Setter; import java.math.BigDecimal; import java.util.Map; @@ -14,7 +15,7 @@ */ @Getter @NoArgsConstructor -@AllArgsConstructor +@Setter public class StrategyOperand { @Schema(description = "항의 타입: \"indicator\", \"price\", \"const\"") @@ -34,4 +35,6 @@ public class StrategyOperand { @Schema(description = "지표의 파라미터 맵 (예: {\"length\": 20})") private Map params; + + //private String transform; // 거래량 관련 필드, 추후 적용 고려 } \ No newline at end of file diff --git a/backend/src/main/java/org/sejongisc/backend/backtest/dto/TradeLog.java b/backend/src/main/java/org/sejongisc/backend/backtest/dto/TradeLog.java new file mode 100644 index 00000000..aadfebd1 --- /dev/null +++ b/backend/src/main/java/org/sejongisc/backend/backtest/dto/TradeLog.java @@ -0,0 +1,19 @@ +package org.sejongisc.backend.backtest.dto; + +import java.math.BigDecimal; +import java.time.LocalDateTime; + +public class TradeLog { + public enum Type { BUY, SELL } + public final Type type; + public final LocalDateTime time; + public final BigDecimal price; + public final BigDecimal shares; + + public TradeLog(Type type, LocalDateTime time, BigDecimal price, BigDecimal shares) { + this.type = type; + this.time = time; + this.price = price; + this.shares = shares; + } +} \ No newline at end of file diff --git a/backend/src/main/java/org/sejongisc/backend/backtest/entity/BacktestRunMetrics.java b/backend/src/main/java/org/sejongisc/backend/backtest/entity/BacktestRunMetrics.java index f45c10e5..72831aed 100644 --- a/backend/src/main/java/org/sejongisc/backend/backtest/entity/BacktestRunMetrics.java +++ b/backend/src/main/java/org/sejongisc/backend/backtest/entity/BacktestRunMetrics.java @@ -27,7 +27,7 @@ public class BacktestRunMetrics { private BigDecimal totalReturn; // 총 수익률 @Column(nullable = false, precision = 12, scale = 6) - private BigDecimal maxDrawdown; // 최대 낙폭 + private BigDecimal maxDrawdown; // 최대 낙폭, 퍼센티지로 계산됨 @Column(nullable = false, precision = 12, scale = 6) private BigDecimal sharpeRatio; // 샤프 지수 diff --git a/backend/src/main/java/org/sejongisc/backend/backtest/repository/BacktestRunRepository.java b/backend/src/main/java/org/sejongisc/backend/backtest/repository/BacktestRunRepository.java index 51276412..95ef0f94 100644 --- a/backend/src/main/java/org/sejongisc/backend/backtest/repository/BacktestRunRepository.java +++ b/backend/src/main/java/org/sejongisc/backend/backtest/repository/BacktestRunRepository.java @@ -7,6 +7,7 @@ import org.springframework.stereotype.Repository; import java.util.List; +import java.util.Optional; import java.util.UUID; @Repository @@ -16,4 +17,10 @@ public interface BacktestRunRepository extends JpaRepository "WHERE t.templateId = :templateTemplateId " + "ORDER BY br.startedAt DESC") List findByTemplate_TemplateIdWithTemplate(@Param("templateTemplateId") UUID templateTemplateId); + + @Query("SELECT br FROM BacktestRun br " + + "LEFT JOIN FETCH br.template t " + // template은 없을 수 있기에 left join + "JOIN FETCH br.user u " + + "WHERE br.id = :backtestRunId ") + Optional findByIdWithMember(@Param("backtestRunId") Long backtestRunId); } diff --git a/backend/src/main/java/org/sejongisc/backend/backtest/service/BacktestService.java b/backend/src/main/java/org/sejongisc/backend/backtest/service/BacktestService.java index 77dfed38..3d04c898 100644 --- a/backend/src/main/java/org/sejongisc/backend/backtest/service/BacktestService.java +++ b/backend/src/main/java/org/sejongisc/backend/backtest/service/BacktestService.java @@ -6,6 +6,7 @@ import lombok.extern.slf4j.Slf4j; import org.sejongisc.backend.backtest.dto.BacktestRequest; import org.sejongisc.backend.backtest.dto.BacktestResponse; +import org.sejongisc.backend.backtest.dto.BacktestRunMetricsResponse; import org.sejongisc.backend.backtest.entity.BacktestRun; import org.sejongisc.backend.backtest.entity.BacktestRunMetrics; import org.sejongisc.backend.backtest.entity.BacktestStatus; @@ -15,12 +16,11 @@ import org.sejongisc.backend.common.exception.ErrorCode; import org.sejongisc.backend.template.entity.Template; import org.sejongisc.backend.template.repository.TemplateRepository; +import org.sejongisc.backend.user.dao.UserRepository; import org.sejongisc.backend.user.entity.User; -import org.springframework.http.HttpStatus; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; -import java.time.LocalDateTime; import java.util.List; import java.util.UUID; @@ -33,36 +33,31 @@ public class BacktestService { private final TemplateRepository templateRepository; private final BacktestingEngine backtestingEngine; private final ObjectMapper objectMapper; - private final EntityManager em; + private final UserRepository userRepository; public BacktestResponse getBacktestStatus(Long backtestRunId, UUID userId) { - // TODO : 백테스트 상태 조회 로직 구현 (진행 중, 완료, 실패 등) + log.info("백테스팅 실행 상태 조회를 시작합니다."); BacktestRun backtestRun = findBacktestRunByIdAndVerifyUser(backtestRunId, userId); return BacktestResponse.builder() - .id(backtestRun.getId()) - .paramsJson(backtestRun.getParamsJson()) - .title(backtestRun.getTitle()) - .status(backtestRun.getStatus()) - .startDate(backtestRun.getStartDate()) - .endDate(backtestRun.getEndDate()) - .template(backtestRun.getTemplate()) + .backtestRun(backtestRun) .build(); } @Transactional public BacktestResponse getBackTestDetails(Long backtestRunId, UUID userId) { - BacktestRunMetrics backtestRunMetrics = backtestRunMetricsRepository.findByBacktestRunId(backtestRunId) - .orElse(null); BacktestRun backtestRun = findBacktestRunByIdAndVerifyUser(backtestRunId, userId); + if (backtestRun.getStatus() != BacktestStatus.COMPLETED) { + return BacktestResponse.builder() + .backtestRun(backtestRun) + .build(); + } + + BacktestRunMetrics backtestRunMetrics = backtestRunMetricsRepository.findByBacktestRunId(backtestRunId) + .orElseThrow(() -> new CustomException(ErrorCode.BACKTEST_METRICS_NOT_FOUND)); + return BacktestResponse.builder() - .id(backtestRun.getId()) - .paramsJson(backtestRun.getParamsJson()) - .title(backtestRun.getTitle()) - .status(backtestRun.getStatus()) - .startDate(backtestRun.getStartDate()) - .endDate(backtestRun.getEndDate()) - .template(backtestRun.getTemplate()) - .backtestRunMetrics(backtestRunMetrics) + .backtestRun(backtestRun) + .backtestRunMetricsResponse(BacktestRunMetricsResponse.fromEntity(backtestRunMetrics)) .build(); } @@ -81,11 +76,11 @@ public void addBacktestTemplate(BacktestRequest request) { } public BacktestResponse runBacktest(BacktestRequest request) { - User userRef = em.getReference(User.class, request.getUserId()); - - Template templateRef = null; + User user = userRepository.findById(request.getUserId()) + .orElseThrow(() -> new CustomException(ErrorCode.USER_NOT_FOUND)); + Template template = null; if (request.getTemplateId() != null) - templateRef = em.getReference(Template.class, request.getTemplateId()); + template = findTemplateByIdAndVerifyUser(request.getTemplateId(), request.getUserId()); String paramsJson; try { @@ -97,8 +92,8 @@ public BacktestResponse runBacktest(BacktestRequest request) { // BacktestRun 엔티티를 "PENDING" 상태로 생성 BacktestRun backtestRun = BacktestRun.builder() - .user(userRef) - .template(templateRef) + .user(user) + .template(template) .title(request.getTitle()) .paramsJson(paramsJson) .startDate(request.getStartDate()) @@ -114,13 +109,7 @@ public BacktestResponse runBacktest(BacktestRequest request) { // 사용자에게 실행 중 응답 반환 return BacktestResponse.builder() - .id(savedRun.getId()) - .paramsJson(savedRun.getParamsJson()) - .title(savedRun.getTitle()) - .status(savedRun.getStatus()) - .startDate(savedRun.getStartDate()) - .endDate(savedRun.getEndDate()) - .template(templateRef) + .backtestRun(savedRun) .build(); } @@ -159,7 +148,7 @@ private Template findTemplateByIdAndVerifyUser(UUID templateId, UUID userId) { } private BacktestRun findBacktestRunByIdAndVerifyUser(Long backtestRunId, UUID userId) { - BacktestRun backtestRun = backtestRunRepository.findById(backtestRunId) + BacktestRun backtestRun = backtestRunRepository.findByIdWithMember(backtestRunId) .orElseThrow(() -> new CustomException(ErrorCode.BACKTEST_NOT_FOUND)); if (!backtestRun.getUser().getUserId().equals(userId)) { diff --git a/backend/src/main/java/org/sejongisc/backend/backtest/service/BacktestingEngine.java b/backend/src/main/java/org/sejongisc/backend/backtest/service/BacktestingEngine.java index 2230cc8d..a5adcb11 100644 --- a/backend/src/main/java/org/sejongisc/backend/backtest/service/BacktestingEngine.java +++ b/backend/src/main/java/org/sejongisc/backend/backtest/service/BacktestingEngine.java @@ -4,6 +4,7 @@ import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.sejongisc.backend.backtest.dto.BacktestRunRequest; +import org.sejongisc.backend.backtest.dto.TradeLog; import org.sejongisc.backend.backtest.entity.BacktestRun; import org.sejongisc.backend.backtest.entity.BacktestRunMetrics; import org.sejongisc.backend.backtest.entity.BacktestStatus; @@ -41,41 +42,39 @@ public class BacktestingEngine { private final ObjectMapper objectMapper; @Async - @Transactional + //@Transactional(propagation = Propagation.REQUIRES_NEW) @Async는 새로운 쓰레드에서 실행되므로 DB 작업 수행 시 주석 제거 필요 public void execute(Long backtestRunId) { log.info("백테스팅 실행이 시작됩니다. 실행 ID : {}", backtestRunId); BacktestRun backtestRun = backtestRunRepository.findById(backtestRunId) .orElseThrow(() -> new CustomException(ErrorCode.BACKTEST_NOT_FOUND)); + // 거래 로그 리스트 초기화 + List tradeLogs = new ArrayList<>(); + try { backtestRun.setStatus(BacktestStatus.RUNNING); backtestRun.setStartedAt(LocalDateTime.now()); backtestRunRepository.save(backtestRun); log.debug("백테스팅 상태 RUNNING 으로 변경됨. ID : {}", backtestRunId); - // 전략(JSON)을 DTO로 파싱 log.debug("paramsJson: {}", backtestRun.getParamsJson()); BacktestRunRequest strategyDto = objectMapper.readValue(backtestRun.getParamsJson(), BacktestRunRequest.class); String ticker = strategyDto.getTicker(); - log.debug("백테스팅 대상 티커: {}", ticker); + log.info("백테스팅 대상 티커: {}", ticker); - // DB에서 가격 데이터 로드 List priceDataList = priceDataRepository.findByTickerAndDateBetweenOrderByDateAsc( ticker, backtestRun.getStartDate(), backtestRun.getEndDate()); - log.debug("가격 데이터 로드 완료. 데이터 개수: {}", priceDataList.size()); + log.info("가격 데이터 로드 완료. 데이터 개수: {}", priceDataList.size()); if (priceDataList.isEmpty()) { throw new CustomException(ErrorCode.PRICE_DATA_NOT_FOUND); } - // Ta4j BarSeries 생성 BarSeries series = ta4jHelper.createBarSeries(priceDataList); Map> indicatorCache = new HashMap<>(); log.debug("BarSeries 생성 완료. 바 개수: {}", series.getBarCount()); - // 매수/매도 룰 생성 Rule buyRule = ta4jHelper.buildCombinedRule(strategyDto.getBuyConditions(), series, indicatorCache); Rule sellRule = ta4jHelper.buildCombinedRule(strategyDto.getSellConditions(), series, indicatorCache); - // 포트폴리오 초기화 BigDecimal initialCapital = strategyDto.getInitialCapital(); BigDecimal cash = initialCapital; BigDecimal shares = BigDecimal.ZERO; @@ -83,60 +82,69 @@ public void execute(Long backtestRunId) { // MDD 및 수익률 추적용 리스트 List dailyPortfolioValue = new ArrayList<>(); + // 일일 수익률 리스트 (샤프 비율 계산에 사용) + List dailyReturns = new ArrayList<>(); + BigDecimal peakValue = initialCapital; BigDecimal maxDrawdown = BigDecimal.ZERO; + BigDecimal previousValue = initialCapital; // 전날 포트폴리오 가치 for (int i = 0; i < series.getBarCount(); i++) { - // 오늘 날짜의 종가 가져오기 (Num -> BigDecimal) - Num numClosePrice = series.getBar(i).getClosePrice(); // Num 객체 반환 + LocalDateTime currentTime = series.getBar(i).getEndTime().toLocalDateTime(); + Num numClosePrice = series.getBar(i).getClosePrice(); BigDecimal currentClosePrice = new BigDecimal(numClosePrice.toString()); - // 전략 평가 boolean shouldBuy = buyRule.isSatisfied(i); boolean shouldSell = sellRule.isSatisfied(i); - // 거래 실행 및 포트폴리오 관리 // "매수" if (shares.compareTo(BigDecimal.ZERO) == 0 && shouldBuy) { - shares = cash.divide(currentClosePrice, 8, RoundingMode.HALF_UP); + BigDecimal buyShares = cash.divide(currentClosePrice, 8, RoundingMode.HALF_UP); + + // 거래 로그 기록 + tradeLogs.add(new TradeLog(TradeLog.Type.BUY, currentTime, currentClosePrice, buyShares)); + + shares = buyShares; cash = BigDecimal.ZERO; tradesCount++; - log.debug("[{}] BUY at {}", series.getBar(i).getEndTime().toLocalDate(), currentClosePrice); - + log.info("[{}] BUY at {}", currentTime.toLocalDate(), currentClosePrice); } // "매도" else if (shares.compareTo(BigDecimal.ZERO) > 0 && shouldSell) { - cash = shares.multiply(currentClosePrice); + BigDecimal tradeShares = shares; // 매도 주식 수 + BigDecimal tradeValue = shares.multiply(currentClosePrice); + + // 거래 로그 기록 + tradeLogs.add(new TradeLog(TradeLog.Type.SELL, currentTime, currentClosePrice, tradeShares)); + + cash = tradeValue; shares = BigDecimal.ZERO; - log.debug("[{}] SELL at {}", series.getBar(i).getEndTime().toLocalDate(), currentClosePrice); + log.info("[{}] SELL at {}", currentTime.toLocalDate(), currentClosePrice); } + // 일일 포트폴리오 가치 계산 BigDecimal currentTotalValue = cash.add(shares.multiply(currentClosePrice)); dailyPortfolioValue.add(currentTotalValue); + + // 일일 수익률 계산 및 MDD 갱신 + if (i > 0) { + BigDecimal dailyReturn = currentTotalValue.subtract(previousValue) + .divide(previousValue, 8, RoundingMode.HALF_UP); + dailyReturns.add(dailyReturn); + } + previousValue = currentTotalValue; + if (currentTotalValue.compareTo(peakValue) > 0) peakValue = currentTotalValue; - BigDecimal drawdown = peakValue.subtract(currentTotalValue).divide(peakValue, 4, RoundingMode.HALF_UP); - // MDD 갱신 + BigDecimal drawdown = peakValue.subtract(currentTotalValue).divide(peakValue, 8, RoundingMode.HALF_UP); if (drawdown.compareTo(maxDrawdown) > 0) maxDrawdown = drawdown; } - // 최종 지표 계산 - BigDecimal finalPortfolioValue = dailyPortfolioValue.getLast(); - // 총수익률 = (최종자산 / 초기자본) - 1 - BigDecimal totalReturnPct = finalPortfolioValue.divide(initialCapital, 4, RoundingMode.HALF_UP) - .subtract(BigDecimal.ONE); - // MDD (백분율로 변환) - BigDecimal maxDrawdownPct = maxDrawdown.multiply(BigDecimal.valueOf(-100)); - - BacktestRunMetrics metrics = BacktestRunMetrics.builder() - .backtestRun(backtestRun) - .totalReturn(totalReturnPct) - .maxDrawdown(maxDrawdownPct) - .sharpeRatio(BigDecimal.ZERO) // TODO: Sharpe 계산 (일일 수익률 표준편차 필요) - .avgHoldDays(BigDecimal.ZERO) // TODO: 평균 보유일 계산 (거래 로그 필요) - .tradesCount(tradesCount) - .build(); + + // 최종 지표 계산 및 저장 + BacktestRunMetrics metrics = calculateMetrics(backtestRun, initialCapital, tradeLogs, dailyPortfolioValue, dailyReturns, maxDrawdown, tradesCount); backtestRunMetricsRepository.save(metrics); backtestRun.setStatus(BacktestStatus.COMPLETED); + } catch (Exception e) { log.error("Backtest execution failed for run ID: {}", backtestRunId, e); backtestRun.setStatus(BacktestStatus.FAILED); @@ -146,4 +154,76 @@ else if (shares.compareTo(BigDecimal.ZERO) > 0 && shouldSell) { backtestRunRepository.save(backtestRun); } } -} + + // ---------------------------------------------------------------------- + // 지표 계산 보조 메서드 + // ---------------------------------------------------------------------- + private BacktestRunMetrics calculateMetrics(BacktestRun backtestRun, BigDecimal initialCapital, + List tradeLogs, List dailyPortfolioValue, + List dailyReturns, BigDecimal maxDrawdown, int tradesCount) { + + BigDecimal finalPortfolioValue = dailyPortfolioValue.getLast(); + BigDecimal totalReturnPct = finalPortfolioValue.divide(initialCapital, 4, RoundingMode.HALF_UP) + .subtract(BigDecimal.ONE); + + BigDecimal maxDrawdownPct = maxDrawdown.multiply(BigDecimal.valueOf(-100)).setScale(4, RoundingMode.HALF_UP); + + BigDecimal sharpeRatio = calculateSharpeRatio(dailyReturns); + BigDecimal avgHoldDays = calculateAvgHoldDays(tradeLogs); + + return BacktestRunMetrics.builder() + .backtestRun(backtestRun) + .totalReturn(totalReturnPct) + .maxDrawdown(maxDrawdownPct) + .sharpeRatio(sharpeRatio) + .avgHoldDays(avgHoldDays) + .tradesCount(tradesCount) + .build(); + } + + private BigDecimal calculateSharpeRatio(List dailyReturns) { + if (dailyReturns.isEmpty()) return BigDecimal.ZERO; + + BigDecimal sum = dailyReturns.stream().reduce(BigDecimal.ZERO, BigDecimal::add); + BigDecimal mean = sum.divide(BigDecimal.valueOf(dailyReturns.size()), 8, RoundingMode.HALF_UP); + + BigDecimal varianceSum = dailyReturns.stream() + .map(r -> r.subtract(mean)) + .map(d -> d.multiply(d)) + .reduce(BigDecimal.ZERO, BigDecimal::add); + + BigDecimal variance = varianceSum.divide(BigDecimal.valueOf(dailyReturns.size()), 8, RoundingMode.HALF_UP); + BigDecimal standardDeviation = BigDecimal.valueOf(Math.sqrt(variance.doubleValue())); + + if (standardDeviation.compareTo(BigDecimal.ZERO) == 0) return BigDecimal.ZERO; + + // 연율화 샤프 비율 (거래일 기준 252일 가정) + BigDecimal annualizationFactor = BigDecimal.valueOf(Math.sqrt(252)); + BigDecimal sharpeRatio = mean.divide(standardDeviation, 8, RoundingMode.HALF_UP).multiply(annualizationFactor); + + return sharpeRatio.setScale(4, RoundingMode.HALF_UP); + } + + private BigDecimal calculateAvgHoldDays(List tradeLogs) { + List holdDurations = new ArrayList<>(); + LocalDateTime currentBuyTime = null; + + for (TradeLog log : tradeLogs) { + if (log.type == TradeLog.Type.BUY) { + currentBuyTime = log.time; + } else if (log.type == TradeLog.Type.SELL && currentBuyTime != null) { + long days = java.time.temporal.ChronoUnit.DAYS.between(currentBuyTime.toLocalDate(), log.time.toLocalDate()); + holdDurations.add(days); + currentBuyTime = null; + } + } + + if (holdDurations.isEmpty()) return BigDecimal.ZERO; + + long totalDays = holdDurations.stream().reduce(0L, Long::sum); + BigDecimal avgHoldDays = BigDecimal.valueOf(totalDays) + .divide(BigDecimal.valueOf(holdDurations.size()), 2, RoundingMode.HALF_UP); + + return avgHoldDays; + } +} \ No newline at end of file diff --git a/backend/src/main/java/org/sejongisc/backend/backtest/service/Ta4jHelperService.java b/backend/src/main/java/org/sejongisc/backend/backtest/service/Ta4jHelperService.java index c3373ad0..fb5495c7 100644 --- a/backend/src/main/java/org/sejongisc/backend/backtest/service/Ta4jHelperService.java +++ b/backend/src/main/java/org/sejongisc/backend/backtest/service/Ta4jHelperService.java @@ -9,18 +9,17 @@ import org.ta4j.core.BaseBarSeries; import org.ta4j.core.Indicator; import org.ta4j.core.Rule; -import org.ta4j.core.indicators.CachedIndicator; // ⭐️ MACD Hist 구현용 +import org.ta4j.core.indicators.CachedIndicator; import org.ta4j.core.indicators.EMAIndicator; import org.ta4j.core.indicators.MACDIndicator; import org.ta4j.core.indicators.RSIIndicator; import org.ta4j.core.indicators.SMAIndicator; import org.ta4j.core.indicators.helpers.*; +import org.ta4j.core.indicators.ATRIndicator; import org.ta4j.core.num.Num; -import org.ta4j.core.rules.*; // IsEqualRule, AndRule, OrRule, OverIndicatorRule 등 +import org.ta4j.core.rules.*; -import java.math.BigDecimal; import java.time.ZoneId; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -28,12 +27,11 @@ @Service @RequiredArgsConstructor public class Ta4jHelperService { - /** * PriceData 리스트를 ta4j의 BarSeries로 변환합니다. */ public BarSeries createBarSeries(List priceDataList) { - // ⭐️ (수정) BarSeries 이름에 Ticker 추가 + // BarSeries 이름에 Ticker 추가 BarSeries series = new BaseBarSeries(priceDataList.get(0).getTicker()); for (PriceData p : priceDataList) { series.addBar( @@ -46,46 +44,36 @@ public BarSeries createBarSeries(List priceDataList) { /** * DTO 조건(List)을 ta4j의 Rule 객체로 빌드합니다. - * "isAbsolute" 로직(✳️무조건 OR ⚪️일반)을 포함합니다. + * "isAbsolute" 로직을 포함합니다. */ - public Rule buildCombinedRule(List conditions, BarSeries series, - Map> indicatorCache) { - + public Rule buildCombinedRule(List conditions, BarSeries series, Map> indicatorCache) { if (series.isEmpty()) { throw new IllegalArgumentException("Cannot build rules on an empty series."); } - - // ⭐️ (수정) "1"과 "0"에 해당하는 Num 객체를 시리즈에서 가져옴 Num sampleNum = series.getBar(0).getClosePrice(); Num one = sampleNum.numOf(1); Num zero = sampleNum.numOf(0); - // ⭐️ (수정) "1"과 "0"에 해당하는 Indicator를 '먼저' 생성 Indicator indicatorOne = new ConstantIndicator<>(series, one); Indicator indicatorZero = new ConstantIndicator<>(series, zero); - // "FalseRule" 대체: "1 == 0" 규칙 (항상 false) Rule falseRule = new IsEqualRule(indicatorOne, indicatorZero); - // "TrueRule" 대체: "1 == 1" 규칙 (항상 true) Rule trueRule = new IsEqualRule(indicatorOne, indicatorOne); if (conditions == null || conditions.isEmpty()) { - return falseRule; // 조건이 없으면 항상 false + return falseRule; } - // 1. ✳️ '무조건' 조건과 ⚪️ '일반' 조건으로 분리 Map> partitioned = conditions.stream() .collect(Collectors.partitioningBy(StrategyCondition::isAbsolute)); - List absoluteConditions = partitioned.get(true); List standardConditions = partitioned.get(false); - Rule absoluteRule; Rule standardRule; - // 2. ✳️ '무조건' 조건들을 OR로 묶음 + // '무조건' 조건들을 OR로 묶음 if (absoluteConditions.isEmpty()) { - absoluteRule = falseRule; // ✳️ 조건 없음 + absoluteRule = falseRule; } else { Rule combinedOrRule = buildSingleRule(absoluteConditions.get(0), series, indicatorCache); for (int i = 1; i < absoluteConditions.size(); i++) { @@ -97,9 +85,9 @@ public Rule buildCombinedRule(List conditions, BarSeries seri absoluteRule = combinedOrRule; } - // 3. ⚪️ '일반' 조건들을 AND로 묶음 + // '일반' 조건들을 AND로 묶음 if (standardConditions.isEmpty()) { - standardRule = falseRule; // ⚪️ 조건 없음 + standardRule = falseRule; } else { Rule combinedAndRule = buildSingleRule(standardConditions.get(0), series, indicatorCache); for (int i = 1; i < standardConditions.size(); i++) { @@ -110,8 +98,6 @@ public Rule buildCombinedRule(List conditions, BarSeries seri } standardRule = combinedAndRule; } - - // 4. 최종 결합: (✳️무조건 OR ⚪️일반) return new OrRule(absoluteRule, standardRule); } @@ -120,58 +106,54 @@ public Rule buildCombinedRule(List conditions, BarSeries seri */ private Rule buildSingleRule(StrategyCondition condition, BarSeries series, Map> indicatorCache) { - Indicator left = resolveOperand(condition.getLeftOperand(), series, indicatorCache); Indicator right = resolveOperand(condition.getRightOperand(), series, indicatorCache); - // ⭐️ (스크린샷 0.13 버전 규칙 기준) - switch (condition.getOperator()) { - case "GT": - return new OverIndicatorRule(left, right); - case "GTE": - return new IsEqualRule(left, right).or(new OverIndicatorRule(left, right)); - case "LT": - return new UnderIndicatorRule(left, right); - case "LTE": - return new IsEqualRule(left, right).or(new UnderIndicatorRule(left, right)); - case "EQ": - return new IsEqualRule(left, right); - case "CROSSES_ABOVE": - return new CrossedUpIndicatorRule(left, right); - case "CROSSES_BELOW": - return new CrossedDownIndicatorRule(left, right); - default: - throw new IllegalArgumentException("Unknown operator: " + condition.getOperator()); - } + return switch (condition.getOperator()) { + case "GT" -> new OverIndicatorRule(left, right); + case "GTE" -> new IsEqualRule(left, right).or(new OverIndicatorRule(left, right)); + case "LT" -> new UnderIndicatorRule(left, right); + case "LTE" -> new IsEqualRule(left, right).or(new UnderIndicatorRule(left, right)); + case "EQ" -> new IsEqualRule(left, right); + case "CROSSES_ABOVE" -> new CrossedUpIndicatorRule(left, right); + case "CROSSES_BELOW" -> new CrossedDownIndicatorRule(left, right); + default -> throw new IllegalArgumentException("Unknown operator: " + condition.getOperator()); + }; } /** - * StrategyOperand DTO를 ta4j Indicator 객체로 "번역" + * StrategyOperand DTO를 ta4j Indicator 객체로 번역 */ private Indicator resolveOperand(StrategyOperand operand, BarSeries series, Map> indicatorCache) { if (operand == null) return null; - + validateOperand(operand); String key = generateIndicatorKey(operand); if (indicatorCache.containsKey(key)) { return indicatorCache.get(key); } - Indicator indicator; - switch (operand.getType()) { - case "price": - indicator = createPriceIndicator(operand.getPriceField(), series); - break; - case "indicator": - indicator = createIndicator(operand, series, indicatorCache); - break; - case "const": + Num indicatorOne = series.getBar(0).getClosePrice().numOf(1); // 1.0에 해당하는 Num + + Indicator indicator = switch (operand.getType()) { + case "price" -> createPriceIndicator(operand.getPriceField(), series); + case "indicator" -> { + Indicator baseIndicator = createIndicator(operand, series, indicatorCache); + // Transform 로직 추가 + //if ("ATR".equals(operand.getIndicatorCode()) && "pctOfPrice".equals(operand.getTransform())) { + // ATR/ClosePrice 비율을 계산하는 Indicator 생성 + //Indicator closePrice = createPriceIndicator("Close", series); + // ATR / ClosePrice = ATR * (1 / ClosePrice) + //yield new MultiplierIndicator(baseIndicator, new DivisionIndicator(new ConstantIndicator<>(series, indicatorOne), closePrice)); + //} + yield baseIndicator; + } + case "const" -> { Num constValue = series.getBar(0).getClosePrice().numOf(operand.getConstantValue()); - indicator = new ConstantIndicator<>(series, constValue); - break; - default: - throw new IllegalArgumentException("Unknown operand type: " + operand.getType()); - } + yield new ConstantIndicator<>(series, constValue); + } + default -> throw new IllegalArgumentException("Unknown operand type: " + operand.getType()); + }; indicatorCache.put(key, indicator); return indicator; @@ -179,19 +161,14 @@ private Indicator resolveOperand(StrategyOperand operand, BarSeries series, // 팩토리 헬퍼 1: 원본 가격 지표 생성 private Indicator createPriceIndicator(String field, BarSeries series) { - switch (field) { - case "Open": - return new OpenPriceIndicator(series); - case "High": - return new HighPriceIndicator(series); - case "Low": - return new LowPriceIndicator(series); - case "Volume": - return new VolumeIndicator(series, 0); - case "Close": - default: - return new ClosePriceIndicator(series); - } + return switch (field) { + case "Open" -> new OpenPriceIndicator(series); + case "High" -> new HighPriceIndicator(series); + case "Low" -> new LowPriceIndicator(series); + case "Volume" -> new VolumeIndicator(series, 0); + case "Close" -> new ClosePriceIndicator(series); + default -> throw new IllegalArgumentException("Unknown priceField: " + field); + }; } // 팩토리 헬퍼 2: 보조 지표 생성 @@ -200,10 +177,7 @@ private Indicator createIndicator(StrategyOperand operand, BarSeries series String code = operand.getIndicatorCode(); Map params = operand.getParams(); - Indicator baseIndicator = resolveOperand( - new StrategyOperand("price", null, null, null, "Close", null), - series, cache - ); + Indicator baseIndicator = createPriceIndicator("Close", series); switch (code) { case "SMA": @@ -221,26 +195,23 @@ private Indicator createIndicator(StrategyOperand operand, BarSeries series int signal = ((Number) params.get("signal")).intValue(); MACDIndicator macd = new MACDIndicator(baseIndicator, fast, slow); - Indicator signalLine = new EMAIndicator(macd, signal); // Signal 라인 생성 - - switch (operand.getOutput()) { - case "macd": - return macd; - case "signal": - return signalLine; - case "hist": - // ⭐️ (변경) MACDHistogramIndicator -> 수동 계산 클래스 - return new ManualMACDHistogramIndicator(macd, signalLine); - default: - return macd; - } - // TODO: ATR, 볼린저 밴드 등 다른 지표 추가... + Indicator signalLine = new EMAIndicator(macd, signal); + + return switch (operand.getOutput()) { + case "macd" -> macd; + case "signal" -> signalLine; + case "hist" -> new ManualMACDHistogramIndicator(macd, signalLine); + default -> macd; + }; + //case "ATR": + //int atrLength = ((Number) params.get("length")).intValue(); + //return new ATRIndicator(series, atrLength); default: throw new IllegalArgumentException("Unknown indicator code: " + code); } } - // Operand DTO로부터 Map의 키를 생성 + // Operand DTO 로부터 Map의 키를 생성 private String generateIndicatorKey(StrategyOperand operand) { if (operand == null) return "null_operand"; switch (operand.getType()) { @@ -257,6 +228,10 @@ private String generateIndicatorKey(StrategyOperand operand) { if (operand.getOutput() != null && !"value".equals(operand.getOutput())) { key += "." + operand.getOutput(); } + // Transform 정보도 Key에 포함 + //if (operand.getTransform() != null) { + // key += "~" + operand.getTransform(); + //} return key; default: return "unknown_operand"; @@ -264,7 +239,7 @@ private String generateIndicatorKey(StrategyOperand operand) { } /** - * ⭐️ (신규) MACD 히스토그램 수동 계산 클래스 + * MACD 히스토그램 수동 계산 클래스 * (MACDIndicator - EMAIndicator(MACDIndicator, signalLength)) */ private static class ManualMACDHistogramIndicator extends CachedIndicator { @@ -272,7 +247,6 @@ private static class ManualMACDHistogramIndicator extends CachedIndicator { private final Indicator signal; public ManualMACDHistogramIndicator(Indicator macd, Indicator signal) { - // 부모 클래스에 BarSeries를 전달해야 함 (macd에서 가져옴) super(macd); this.macd = macd; this.signal = signal; @@ -280,9 +254,36 @@ public ManualMACDHistogramIndicator(Indicator macd, Indicator signal) @Override protected Num calculate(int index) { - // MACD 값 - Signal 값 return macd.getValue(index).minus(signal.getValue(index)); } } -} + private void validateOperand(StrategyOperand operand) { + if (operand.getType() == null) { + throw new IllegalArgumentException("Operand 'type' must not be null."); + } + + switch (operand.getType()) { + case "price": + if (operand.getPriceField() == null) { + throw new IllegalArgumentException("Operand of type 'price' must have a non-null 'priceField'."); + } + break; + case "indicator": + if (operand.getIndicatorCode() == null) { + throw new IllegalArgumentException("Operand of type 'indicator' must have a non-null 'indicatorCode'."); + } + if (operand.getParams() == null) { + throw new IllegalArgumentException("Operand of type 'indicator' must have non-null 'params'."); + } + break; + case "const": + if (operand.getConstantValue() == null) { + throw new IllegalArgumentException("Operand of type 'const' must have a non-null 'constantValue'."); + } + break; + default: + throw new IllegalArgumentException("Unknown operand type: " + operand.getType()); + } + } +} \ No newline at end of file diff --git a/backend/src/main/java/org/sejongisc/backend/point/service/PointHistoryService.java b/backend/src/main/java/org/sejongisc/backend/point/service/PointHistoryService.java index e2c472a5..af992fae 100644 --- a/backend/src/main/java/org/sejongisc/backend/point/service/PointHistoryService.java +++ b/backend/src/main/java/org/sejongisc/backend/point/service/PointHistoryService.java @@ -1,6 +1,5 @@ package org.sejongisc.backend.point.service; -import jakarta.persistence.OptimisticLockException; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.sejongisc.backend.common.exception.CustomException; @@ -14,7 +13,6 @@ import org.sejongisc.backend.user.entity.User; import org.springframework.dao.OptimisticLockingFailureException; import org.springframework.data.domain.PageRequest; -import org.springframework.orm.ObjectOptimisticLockingFailureException; import org.springframework.retry.annotation.Backoff; import org.springframework.retry.annotation.Recover; import org.springframework.retry.annotation.Retryable; diff --git a/backend/src/main/java/org/sejongisc/backend/template/controller/TemplateController.java b/backend/src/main/java/org/sejongisc/backend/template/controller/TemplateController.java index 371a380b..cdeca16a 100644 --- a/backend/src/main/java/org/sejongisc/backend/template/controller/TemplateController.java +++ b/backend/src/main/java/org/sejongisc/backend/template/controller/TemplateController.java @@ -66,6 +66,7 @@ public ResponseEntity createTemplate(@RequestBody TemplateRequ ) public ResponseEntity updateTemplate(@RequestBody TemplateRequest request, @AuthenticationPrincipal CustomUserDetails customUserDetails) { + request.setUserId(customUserDetails.getUserId()); return ResponseEntity.ok(templateService.updateTemplate(request)); } diff --git a/backend/src/main/java/org/sejongisc/backend/template/dto/TemplateRequest.java b/backend/src/main/java/org/sejongisc/backend/template/dto/TemplateRequest.java index e4350c3d..d7a0f881 100644 --- a/backend/src/main/java/org/sejongisc/backend/template/dto/TemplateRequest.java +++ b/backend/src/main/java/org/sejongisc/backend/template/dto/TemplateRequest.java @@ -1,5 +1,6 @@ package org.sejongisc.backend.template.dto; +import io.swagger.v3.oas.annotations.media.Schema; import lombok.Getter; import lombok.Setter; import org.sejongisc.backend.user.entity.User; @@ -11,18 +12,18 @@ @Setter public class TemplateRequest { - //@Schema(hidden = true, description = "유저") - private UUID userId; + @Schema(hidden = true, description = "유저") + private UUID userId; // 유저 ID - //@Schema(hidden = true, description = "템플릿 ID") + @Schema(hidden = true, description = "템플릿 ID") private UUID templateId; // 템플릿 ID - //@Schema(description = "템플릿 제목", defaultValue = "기술주 템플릿") + @Schema(description = "템플릿 제목", defaultValue = "기술주 템플릿") private String title; - //@Schema(description = "템플릿 설명", defaultValue = "기술주 템플릿입니다.") + @Schema(description = "템플릿 설명", defaultValue = "기술주 템플릿입니다.") private String description; - //@Schema(description = "템플릿 공개 여부", defaultValue = "false") + @Schema(description = "템플릿 공개 여부", defaultValue = "false") private Boolean isPublic; } diff --git a/backend/src/main/java/org/sejongisc/backend/template/service/TemplateService.java b/backend/src/main/java/org/sejongisc/backend/template/service/TemplateService.java index bfc99857..66d10e5f 100644 --- a/backend/src/main/java/org/sejongisc/backend/template/service/TemplateService.java +++ b/backend/src/main/java/org/sejongisc/backend/template/service/TemplateService.java @@ -11,6 +11,7 @@ import org.sejongisc.backend.template.dto.TemplateResponse; import org.sejongisc.backend.template.entity.Template; import org.sejongisc.backend.template.repository.TemplateRepository; +import org.sejongisc.backend.user.dao.UserRepository; import org.sejongisc.backend.user.entity.User; import org.springframework.stereotype.Service; @@ -23,7 +24,7 @@ public class TemplateService { private final TemplateRepository templateRepository; private final BacktestRunRepository backtestRunRepository; - private final EntityManager em; + private final UserRepository userRepository; // 유저 ID로 템플릿 목록 조회 public TemplateResponse findAllByUserId(UUID userId) { @@ -44,10 +45,10 @@ public TemplateResponse findById(UUID templateId, UUID userId) { // 템플릿 생성 public TemplateResponse createTemplate(TemplateRequest request) { - // userId 만을 가진 FK 전용 프록시 객체 생성 - User userRef = em.getReference(User.class, request.getUserId()); + User user = userRepository.findById(request.getUserId()) + .orElseThrow(() -> new CustomException(ErrorCode.USER_NOT_FOUND)); - Template template = Template.of(userRef, request.getTitle(), + Template template = Template.of(user, request.getTitle(), request.getDescription(), request.getIsPublic()); templateRepository.save(template); diff --git a/backend/src/test/java/org/sejongisc/backend/backtest/service/BacktestServiceTest.java b/backend/src/test/java/org/sejongisc/backend/backtest/service/BacktestServiceTest.java index c98a0f3a..0043d9c8 100644 --- a/backend/src/test/java/org/sejongisc/backend/backtest/service/BacktestServiceTest.java +++ b/backend/src/test/java/org/sejongisc/backend/backtest/service/BacktestServiceTest.java @@ -119,7 +119,7 @@ void getBackTestDetails_success() { BacktestResponse response = backtestService.getBackTestDetails(1L, userId); - assertThat(response.getBacktestRunMetrics().getSharpeRatio()).isEqualTo(BigDecimal.valueOf(1.5)); + assertThat(response.getBacktestRunMetricsResponse().getSharpeRatio()).isEqualTo(BigDecimal.valueOf(1.5)); } @Test