Skip to content

Commit

Permalink
Merge pull request #7 from tsukuba-websci/junya/update/ga-io
Browse files Browse the repository at this point in the history
Feature: GAの出力形式の修正
  • Loading branch information
krmr73 authored Jun 12, 2023
2 parents 475f54e + 3329527 commit 5b76a33
Show file tree
Hide file tree
Showing 7 changed files with 210 additions and 76 deletions.
4 changes: 2 additions & 2 deletions ga/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ results/aps/**/*
results/twitter/**/*
!results/twitter/.gitkeep

results/synthetic/**/*
!results/synthetic/.gitkeep
results/synthetic_target/**/*
!results/synthetic_target/.gitkeep
48 changes: 37 additions & 11 deletions ga/ga.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import csv
import logging
from typing import Any, List, Tuple

Expand All @@ -12,27 +13,34 @@ class GA:
def __init__(
self,
population_size: int,
rate: float,
mutation_rate: float,
cross_rate: float,
target: History2VecResult,
target_data: str,
num_generations: int,
jl_main: Any,
thread_num: int,
archive_dir: str,
min_val: float = -1.0,
max_val: float = 1.0,
debug: bool = True,
is_grid_search: bool = False,
) -> None:
self.population_size = population_size
self.min_val = min_val
self.max_val = max_val
self.rate = rate
self.mutation_rate = mutation_rate
self.cross_rate = cross_rate
self.num_generations = 500
self.num_generations = num_generations

self.target = target
self.target_data = target_data
self.jl_main = jl_main
self.thread_num = thread_num
self.histories = [[] for _ in range(self.population_size)]
self.archives_dir = archive_dir
self.debug = debug
self.is_grid_search = is_grid_search

def tovec(self, history: List[Tuple[int, int]], interval_num: int) -> History2VecResult:
"""相互やり取りの履歴を10個の指標に変換する.
Expand Down Expand Up @@ -72,7 +80,7 @@ def selection(self, population: list, fitness: list) -> list:
"""ルーレット選択.適応度に比例した確率で個体を選択し,親個体にする.この親個体を用いて交叉を行う.
Args:
population (list): 各個体のパラメータ (rho, nu, recentness, friendship) のリスト
population (list): 各個体のパラメータ (rho, nu, recentness, frequency) のリスト
fitness (list): 各個体の適応度
Returns:
Expand All @@ -91,8 +99,8 @@ def crossover(self, parents1: list, parents2: list, children: list) -> list:
"""交叉.親のうちランダムに選んだものを交叉させる.
Args:
parents1 (list): 親1 (rho, nu, recentness, friendship) のリスト
parents2 (list): 親2 (rho, nu, recentness, friendship) のリスト
parents1 (list): 親1 (rho, nu, recentness, frequency) のリスト
parents2 (list): 親2 (rho, nu, recentness, frequency) のリスト
children (list): 子のリスト
Returns:
Expand All @@ -115,25 +123,40 @@ def mutation(self, children: list) -> list:
children (list): 子のリスト
"""
for i in range(self.population_size):
if np.random.rand() < self.rate:
if np.random.rand() < self.mutation_rate:
idx = np.random.randint(4)
children[i][idx] = np.random.uniform(low=self.min_val, high=self.max_val)
return children

def dump_population(self, population: list, generation: int, fitness: list) -> None:
"""個体群をファイルに出力する.
Args:
population (list): 個体群
generation (int): 世代数
fitness (list): 適応度
"""
fp = f"{self.archives_dir}/{str(generation).zfill(8)}.csv"
with open(fp, "w") as f:
writer = csv.writer(f)
writer.writerow(["rho", "nu", "recentness", "frequency", "objective"])
for individual, fit in zip(population, fitness):
writer.writerow([individual[0], individual[1], individual[2], individual[3], -1 * fit])

def plot():
return

def run_init(self) -> list:
"""GAの初期個体群を生成する.
Returns:
population (list): 初期個体群 (rho, nu, recentness, friendship) のリスト
population (list): 初期個体群 (rho, nu, recentness, frequency) のリスト
"""
rho = np.random.uniform(low=0, high=30, size=self.population_size)
nu = np.random.uniform(low=0, high=30, size=self.population_size)
recentness = np.random.uniform(low=self.min_val, high=self.max_val, size=self.population_size)
friendship = np.random.uniform(low=self.min_val, high=self.max_val, size=self.population_size)
population = np.array([rho, nu, recentness, friendship]).T
frequency = np.random.uniform(low=self.min_val, high=self.max_val, size=self.population_size)
population = np.array([rho, nu, recentness, frequency]).T
return population

def run(self) -> Tuple[float, History2VecResult, list]:
Expand Down Expand Up @@ -167,7 +190,7 @@ def run(self) -> Tuple[float, History2VecResult, list]:
rho=population[i][0],
nu=population[i][1],
recentness=population[i][2],
friendship=population[i][3],
frequency=population[i][3],
steps=20000,
)
self.histories[i] = run_model(params)
Expand Down Expand Up @@ -202,6 +225,9 @@ def run(self) -> Tuple[float, History2VecResult, list]:
message = f"Generation {generation}: Best fitness = {best_fitness}, Best params = {best_params}, 10Metrics = {metrics}"
logging.info(message)

# 個体群の出力
self.dump_population(population, generation, fitness)

# 適応度の最小値,ターゲット,最適解,10個の指標を返す
arg = np.argmax(fitness)
return (
Expand Down
2 changes: 1 addition & 1 deletion ga/history2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class Params:
rho: float
nu: float
recentness: float
friendship: float
frequency: float
steps: int


Expand Down
43 changes: 40 additions & 3 deletions ga/io_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import argparse
import csv
import json
import os


def parse_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
Expand All @@ -11,20 +13,43 @@ def parse_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
Returns:
argparse.Namespace: コマンドライン引数のパース結果
"""

parser.add_argument("population_size", type=int, help="個体数")
parser.add_argument("rate", type=float, help="突然変異率")
parser.add_argument("mutation_rate", type=float, help="突然変異率")
parser.add_argument("cross_rate", type=float, help="交叉率")
parser.add_argument(
"target_data",
type=str,
choices=["twitter", "aps", "synthetic_fitting_target"],
choices=["twitter", "aps", "synthetic"],
help="ターゲットデータ",
)
parser.add_argument("rho", type=int, nargs="?", default=None, help="rho")
parser.add_argument("nu", type=int, nargs="?", default=None, help="nu")
parser.add_argument("s", type=str, nargs="?", default=None, choices=["SSW", "WSW"], help="strategy")

parser.add_argument("-p", "--prod", action="store_true", default=False, help="本番実行用フラグ.出力先を変更する.")
parser.add_argument("-f", "--force", action="store_true", default=False, help="既存のファイルを上書きする.")
args = parser.parse_args()
return args


def export_individual(distance: float, individual: list, fpath: str) -> None:
"""個体をCSVファイルに出力する.
Args:
distance (float): ターゲットとの距離
individual (list): 個体を表すタプル.(rho, nu, recentness, frequency)の順
fpath (str): 出力先のパス
"""
header = ["rho", "nu", "recentness", "frequency", "objective"]
objective = distance
row = [*individual, objective]
with open(fpath, "w") as f:
writer = csv.writer(f)
writer.writerow(header)
writer.writerow(row)


def dump_json(result: tuple, fpath: str) -> None:
"""GAの結果をJSONファイルに出力する.
Expand All @@ -38,7 +63,7 @@ def dump_json(result: tuple, fpath: str) -> None:
"rho": result[2][0],
"nu": result[2][1],
"recentness": result[2][2],
"friendship": result[2][3],
"frequency": result[2][3],
},
"target": {},
"result": {},
Expand All @@ -50,6 +75,18 @@ def dump_json(result: tuple, fpath: str) -> None:
json.dump(res, open(fpath, "w"), indent=4)


def pass_run(force, fpath) -> bool:
"""GAの実行をスキップするかどうかを判定する.forceがFalseかつすでに結果が存在する場合にTrueを返す.
Args:
force (bool): 既存のファイルを上書きするかどうか
fpath (str): 出力先のパス
"""
if force:
return False
return os.path.exists(fpath)


def validate(population_size, rate, cross_rate) -> None:
"""GAのパラメータのバリデーション.
Expand Down
Loading

0 comments on commit 5b76a33

Please sign in to comment.