
Commit

Merge pull request #41 from tsukuba-websci/nanami/fix/io
fix & update: change the output format
krmr73 authored Jun 14, 2023
2 parents ca129df + f0828bf commit 5d0a997
Showing 7 changed files with 20 additions and 10 deletions.
2 changes: 2 additions & 0 deletions ga/.gitignore
@@ -11,3 +11,5 @@ results/twitter/**/*
 
 results/synthetic/**/*
 !results/synthetic/.gitkeep
+
+results/grid_search/**/*
6 changes: 4 additions & 2 deletions ga/ga.py
@@ -128,10 +128,12 @@ def dump_population(self, population: list, generation: int, fitness: list) -> None:
             fitness (list): fitness values
         """
         fp = f"{self.archives_dir}/{str(generation).zfill(8)}.csv"
+        population_fitness = sorted(list(zip(population, fitness)), key=lambda x: -x[1])
 
         with open(fp, "w") as f:
             writer = csv.writer(f)
-            writer.writerow(["rho", "nu", "recentness", "frequency", "objective"])
-            for individual, fit in zip(population, fitness):
+            writer.writerow(["rho", "nu", "recentness", "frequency", "distance"])
+            for individual, fit in population_fitness:
                 writer.writerow([individual[0], individual[1], individual[2], individual[3], -1 * fit])

def plot():
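For reference, a minimal, self-contained sketch of what the reordered dump produces, using made-up values; it assumes (as the -1 * fit in the diff suggests) that fitness stores the negated distance to the target, so sorting by descending fitness writes rows in ascending order of distance:

import csv

# Hypothetical individuals (rho, nu, recentness, frequency) and their fitness
# values; fitness is assumed to be the negated distance (larger = closer).
population = [(1.0, 2.0, 0.1, 0.9), (3.0, 1.5, 0.5, 0.5)]
fitness = [-0.42, -0.17]

# Sort best-first, as dump_population now does.
population_fitness = sorted(zip(population, fitness), key=lambda x: -x[1])

with open("00000001.csv", "w") as f:
    writer = csv.writer(f)
    writer.writerow(["rho", "nu", "recentness", "frequency", "distance"])
    for individual, fit in population_fitness:
        # Flip the sign back so the CSV stores a distance.
        writer.writerow([*individual, -1 * fit])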
2 changes: 1 addition & 1 deletion ga/grid_search.py
@@ -131,7 +131,7 @@ def main():
     # Set Up GridSearch
     mutation_rate_iter = [0.01, 0.02, 0.03, 0.04, 0.05]
     cross_rate_iter = [0.8, 0.85, 0.9, 0.95]
-    population_size_iter = [20, 40, 60, 80, 100]
+    population_size_iter = [10, 20, 30, 40, 50]
     num_generations = 100
     output_dir = f"./results/grid_search/{target_data}"
     os.makedirs(output_dir, exist_ok=True)
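A quick sanity check on the new grid (a sketch only; it assumes grid_search.py sweeps the full Cartesian product of the three lists, which is not shown in this hunk): the updated settings still give 5 * 4 * 5 = 100 hyperparameter combinations, now with smaller population sizes.

from itertools import product

# Hyperparameter grid copied from the hunk above.
mutation_rate_iter = [0.01, 0.02, 0.03, 0.04, 0.05]
cross_rate_iter = [0.8, 0.85, 0.9, 0.95]
population_size_iter = [10, 20, 30, 40, 50]

combos = list(product(mutation_rate_iter, cross_rate_iter, population_size_iter))
print(len(combos))  # 100 combinations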
8 changes: 5 additions & 3 deletions ga/io_utils.py
@@ -33,17 +33,19 @@ def parse_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
     return args
 
 
-def export_individual(distance: float, individual: list, fpath: str) -> None:
+def export_individual(
+    distance: float, individual: list, population_size: int, mutation_rate: float, cross_rate: float, fpath: str
+) -> None:
     """Export an individual to a CSV file.
     Args:
         distance (float): distance to the target
        individual (list): tuple representing the individual, in the order (rho, nu, recentness, frequency)
        fpath (str): output file path
     """
-    header = ["rho", "nu", "recentness", "frequency", "objective"]
+    header = ["rho", "nu", "recentness", "frequency", "objective", "population_size", "mutation_rate", "cross_rate"]
     objective = distance
-    row = [*individual, objective]
+    row = [*individual, objective, population_size, mutation_rate, cross_rate]
     with open(fpath, "w") as f:
         writer = csv.writer(f)
         writer.writerow(header)
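To illustrate the new signature (the values and output path below are hypothetical), a call now records the GA hyperparameters next to the individual, producing a single-row CSV:

# Resulting CSV (one header row, one data row):
#   rho,nu,recentness,frequency,objective,population_size,mutation_rate,cross_rate
#   2.1,1.3,0.4,0.6,0.37,30,0.03,0.9
export_individual(
    distance=0.37,
    individual=[2.1, 1.3, 0.4, 0.6],
    population_size=30,
    mutation_rate=0.03,
    cross_rate=0.9,
    fpath="./results/grid_search/example/best.csv",
)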
2 changes: 1 addition & 1 deletion ga/main.py
@@ -152,7 +152,7 @@ def main():
         archive_dir=os.path.join(output_base_dir, "archives"),
     )
 
-    export_individual(min_distance, best_individual, output_fp)
+    export_individual(min_distance, best_individual, population_size, mutation_rate, cross_rate, output_fp)
 
     logging.info(f"Finished GA. Result is dumped to {target_data}")
 
10 changes: 7 additions & 3 deletions qd/main.py
@@ -47,6 +47,7 @@ def __init__(
 
     def run(self):
         self.jl_main, self.thread_num = JuliaInitializer().initialize()
+        history2vec_ = History2Vec(self.jl_main, self.thread_num)
 
         archive: Union[CVTArchive, None] = None
         if os.path.exists(f"{self.archives_dir_path}/archive.pkl"):
@@ -105,7 +106,7 @@ def run(self):
         with Pool(self.thread_num) as pool:
             histories = pool.map(run_model, params_list)
 
-        history_vecs = History2Vec(self.jl_main, self.thread_num).history2vec_parallel(histories, 1000)
+        history_vecs = history2vec_.history2vec_parallel(histories, 1000)
 
         bcs = self.history2bd.run(histories)

@@ -131,7 +132,9 @@ def run(self):
                 },
                 inplace=True,
             )
-            df = df[["rho", "nu", "recentness", "frequency", "objective"]].sort_values(by="objective", ascending=False)
+            df["objective"] = -df["objective"]
+            df.rename(columns={"objective": "distance"}, inplace=True)
+            df = df[["rho", "nu", "recentness", "frequency", "distance"]].sort_values(by="distance", ascending=True)
             df.to_csv(f"{self.archives_dir_path}/{iter:0>8}.csv", index=False)
 
             if iter % 25 == 0:
@@ -141,7 +144,8 @@ def run(self):
                 assert archive.stats is not None, "archive.stats is None!"
                 print(f" - Max Score: {archive.stats.obj_max}")
                 # save best result as csv
-                df.head(1).to_csv(f"{self.result_dir_path}/best.csv", index=False)
+                if not os.path.exists(f"{self.result_dir_path}/best.csv"):
+                    df.head(1).to_csv(f"{self.result_dir_path}/best.csv", index=False)


if __name__ == "__main__":
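For clarity, a standalone sketch of the DataFrame rewrite above, with made-up rows; it assumes the archive's "objective" column holds the negated distance (higher is better), which the sign flip converts back to a distance before renaming and sorting:

import pandas as pd

# Made-up archive rows; "objective" is assumed to be the negated distance.
df = pd.DataFrame({
    "rho": [1.0, 3.0],
    "nu": [2.0, 1.5],
    "recentness": [0.1, 0.5],
    "frequency": [0.9, 0.5],
    "objective": [-0.42, -0.17],
})

df["objective"] = -df["objective"]                          # back to a distance
df.rename(columns={"objective": "distance"}, inplace=True)  # clearer column name
df = df[["rho", "nu", "recentness", "frequency", "distance"]].sort_values(
    by="distance", ascending=True
)
print(df.head(1))  # the row that best.csv would contain (closest to the target)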
File renamed without changes.
