Draft

21 commits
dd7a3e7
feat: learned mapping from latent to real actions
emergenz Jul 19, 2025
78d89be
added npy to array-record
maharajamihir Aug 11, 2025
4e0846b
adding pngs to array records as well
maharajamihir Aug 11, 2025
92be24f
Update input_pipeline/preprocess/npy_to_array_records.py
maharajamihir Aug 18, 2025
6264ea7
Update input_pipeline/preprocess/pngs_to_array_records.py
maharajamihir Aug 19, 2025
eeb24a7
Apply suggestions from code review
maharajamihir Aug 19, 2025
b3148a5
modify coinrun data generation to output array-records, remove unused…
maharajamihir Sep 4, 2025
43bdbba
revert refactoring change
maharajamihir Sep 4, 2025
c2533c8
standardized metadata
maharajamihir Sep 4, 2025
e44cb26
new return format for edge case
maharajamihir Sep 4, 2025
84e90b5
converted frames to uint8 during coinrun datagen
maharajamihir Sep 4, 2025
1ff514f
added non-video file warning, omitted default envname
maharajamihir Sep 4, 2025
e9fbc59
created val and test loss data generation script
maharajamihir Sep 4, 2025
773ab44
path handling fix
maharajamihir Sep 4, 2025
2f8025e
added val loss logic to tokenizer
maharajamihir Sep 4, 2025
3bb22df
added val loss logic to tokenizer
maharajamihir Sep 4, 2025
d53848c
added val loss to lam
maharajamihir Sep 5, 2025
4b964f1
val loss implemented for dynamics model
maharajamihir Sep 5, 2025
d2cdec9
fix: rename dataloader_state to train_dataloader state for proper che…
maharajamihir Sep 5, 2025
70b00b6
Merge branch 'validation-loss' into action-mapper
maharajamihir Sep 8, 2025
8b8ee3c
added action dataloading and nnx migration
maharajamihir Sep 8, 2025
114 changes: 80 additions & 34 deletions generate_dataset.py
@@ -7,49 +7,95 @@
 from pathlib import Path

 from gym3 import types_np
+import os
 import numpy as np
 from procgen import ProcgenGym3Env
 import tyro
+import pickle
+import json
+from array_record.python.array_record_module import ArrayRecordWriter


 @dataclass
 class Args:
-    num_episodes: int = 10000
+    num_episodes_train: int = 10000
+    num_episodes_val: int = 500
+    num_episodes_test: int = 500
     output_dir: str = "data/coinrun_episodes"
     min_episode_length: int = 50
+    seed: int = 0


 args = tyro.cli(Args)
-output_dir = Path(args.output_dir)
-output_dir.mkdir(parents=True, exist_ok=True)

-# --- Generate episodes ---
-i = 0
-metadata = []
-while i < args.num_episodes:
-    seed = np.random.randint(0, 10000)
-    env = ProcgenGym3Env(num=1, env_name="coinrun", start_level=seed)
-    dataseq = []
-
-    # --- Run episode ---
-    for j in range(1000):
-        env.act(types_np.sample(env.ac_space, bshape=(env.num,)))
-        rew, obs, first = env.observe()
-        dataseq.append(obs["rgb"])
-        if first:
-            break
-
-    # --- Save episode ---
-    if len(dataseq) >= args.min_episode_length:
-        episode_data = np.concatenate(dataseq, axis=0)
-        episode_path = output_dir / f"episode_{i}.npy"
-        np.save(episode_path, episode_data.astype(np.uint8))
-        metadata.append({"path": str(episode_path), "length": len(dataseq)})
-        print(f"Episode {i} completed, length: {len(dataseq)}")
-        i += 1
-    else:
-        print(f"Episode too short ({len(dataseq)}), resampling...")
-
-# --- Save metadata ---
-np.save(output_dir / "metadata.npy", metadata)
-print(f"Dataset generated with {len(metadata)} valid episodes")
+
+def generate_episodes(num_episodes, split):
+    i = 0
+    episode_metadata = []
+    while i < num_episodes:
+        seed = np.random.randint(0, 10000)
+        env = ProcgenGym3Env(num=1, env_name="coinrun", start_level=seed)
+        observations_seq = []
+
+        # --- Run episode ---
+        for _ in range(1000):
+            env.act(types_np.sample(env.ac_space, bshape=(env.num,)))
+            _, obs, first = env.observe()
+            observations_seq.append(obs["rgb"])
+            if first:
+                break
+
+        # --- Save episode ---
+        if len(observations_seq) >= args.min_episode_length:
+            observations_data = np.concatenate(observations_seq, axis=0).astype(np.uint8)
+            episode_path = os.path.join(args.output_dir, split, f"episode_{i}.array_record")
+
+            # --- Save as ArrayRecord ---
+            writer = ArrayRecordWriter(str(episode_path), "group_size:1")
+            record = {"raw_video": observations_data.tobytes(), "sequence_length": len(observations_seq)}
+            writer.write(pickle.dumps(record))
+            writer.close()
+
+            episode_metadata.append({"path": str(episode_path), "length": len(observations_seq)})
+            print(f"Episode {i} completed, length: {len(observations_seq)}")
+            i += 1
+        else:
+            print(f"Episode too short ({len(observations_seq)}), resampling...")
+    print(f"Done generating {split} split")
+    return episode_metadata
+
+
+def main():
+    # Set random seed and create dataset directories
+    np.random.seed(args.seed)
+    output_dir = Path(args.output_dir)
+    (output_dir / "train").mkdir(parents=True, exist_ok=True)
+    (output_dir / "val").mkdir(parents=True, exist_ok=True)
+    (output_dir / "test").mkdir(parents=True, exist_ok=True)
+
+    # --- Generate episodes ---
+    train_episode_metadata = generate_episodes(args.num_episodes_train, "train")
+    val_episode_metadata = generate_episodes(args.num_episodes_val, "val")
+    test_episode_metadata = generate_episodes(args.num_episodes_test, "test")
+
+    # --- Save metadata ---
+    metadata = {
+        "env": "coinrun",
+        "num_episodes_train": args.num_episodes_train,
+        "num_episodes_val": args.num_episodes_val,
+        "num_episodes_test": args.num_episodes_test,
+        "avg_episode_len_train": np.mean([ep["length"] for ep in train_episode_metadata]),
+        "avg_episode_len_val": np.mean([ep["length"] for ep in val_episode_metadata]),
+        "avg_episode_len_test": np.mean([ep["length"] for ep in test_episode_metadata]),
+        "episode_metadata_train": train_episode_metadata,
+        "episode_metadata_val": val_episode_metadata,
+        "episode_metadata_test": test_episode_metadata,
+    }
+    with open(output_dir / "metadata.json", "w") as f:
+        json.dump(metadata, f)
+
+    print("Done generating dataset.")
+
+
+if __name__ == "__main__":
+    main()
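Each episode file now holds a single pickled record with the raw frame bytes and the sequence length, so a consumer must supply the frame shape itself. A minimal read-back sketch, assuming procgen coinrun's default 64x64 RGB observations and the ArrayRecordReader from the same array_record module (the file path is illustrative):

# Sketch: decode one generated episode. Assumes 64x64x3 uint8 frames
# (coinrun's default observation shape); the record itself stores only
# the byte buffer and sequence_length.
import pickle
import numpy as np
from array_record.python.array_record_module import ArrayRecordReader

reader = ArrayRecordReader("data/coinrun_episodes/train/episode_0.array_record")
record = pickle.loads(reader.read_all()[0])  # group_size:1 -> one record per file
frames = np.frombuffer(record["raw_video"], dtype=np.uint8)
frames = frames.reshape(record["sequence_length"], 64, 64, 3)
reader.close()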
125 changes: 0 additions & 125 deletions input_pipeline/preprocess/npy_to_tfrecord.py

This file was deleted.
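The TFRecord converter is superseded by the ArrayRecord pipeline (see the npy_to_array_records.py commits above; that script is not shown in this diff). A hypothetical sketch of the equivalent conversion, mirroring the record schema written by generate_dataset.py (the input path is illustrative):

# Hypothetical sketch of the .npy -> ArrayRecord conversion replacing
# this script; writes the same record schema as generate_dataset.py.
import pickle
import numpy as np
from array_record.python.array_record_module import ArrayRecordWriter

episode = np.load("data/coinrun_episodes/episode_0.npy").astype(np.uint8)
writer = ArrayRecordWriter("episode_0.array_record", "group_size:1")
writer.write(pickle.dumps({
    "raw_video": episode.tobytes(),
    "sequence_length": episode.shape[0],
}))
writer.close()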

130 changes: 130 additions & 0 deletions input_pipeline/preprocess/pngs_to_array_records.py
@@ -0,0 +1,130 @@
import os
import numpy as np
from PIL import Image
import tyro
from dataclasses import dataclass
import pickle
import json
import multiprocessing as mp
from array_record.python.array_record_module import ArrayRecordWriter


@dataclass
class Args:
    input_path: str
    output_path: str
    env_name: str
    original_fps: int = 60
    target_fps: int = 10
    target_width: int = 64


def preprocess_pngs(input_dir, output_path, original_fps, target_fps, target_width=None):
    print(f"Processing PNGs in {input_dir}")
    try:
        png_files = sorted([
            f for f in os.listdir(input_dir)
            if f.lower().endswith('.png')
        ], key=lambda x: int(os.path.splitext(x)[0]))

        if not png_files:
            print(f"No PNG files found in {input_dir}")
            return {"path": input_dir, "length": 0}

        # Downsample indices: keep floor(n_total * target_fps / original_fps)
        # evenly spaced frames, e.g. 600 frames at 60 fps -> 100 frames at 10 fps
        n_total = len(png_files)
        if original_fps == target_fps:
            selected_indices = np.arange(n_total)
        else:
            n_target = int(np.floor(n_total * target_fps / original_fps))
            selected_indices = np.linspace(0, n_total - 1, n_target, dtype=int)

        selected_files = [png_files[i] for i in selected_indices]

        # Load images, resizing to target_width while preserving aspect ratio
        frames = []
        for fname in selected_files:
            img = Image.open(os.path.join(input_dir, fname)).convert("RGB")
            if target_width is not None:
                w, h = img.size  # PIL gives (width, height)
                if w != target_width:
                    target_height = int(round(h * (target_width / float(w))))
                    img = img.resize((target_width, target_height), resample=Image.LANCZOS)
            frames.append(np.array(img))

        frames = np.stack(frames, axis=0)  # (n_frames, H, W, 3)
        environment = os.path.basename(os.path.dirname(input_dir))
        episode_id = os.path.basename(input_dir)

        # Write to array_record
        os.makedirs(output_path, exist_ok=True)
        out_file = os.path.join(
            output_path,
            f"{environment}_{episode_id}.array_record"
        )
        writer = ArrayRecordWriter(str(out_file), "group_size:1")
        record = {"raw_video": frames.tobytes(),
                  "environment": environment,
                  "sequence_length": frames.shape[0]}
        writer.write(pickle.dumps(record))
        writer.close()
        print(f"Saved {frames.shape[0]} frames to {out_file}")
        return {"path": input_dir, "length": frames.shape[0]}
    except Exception as e:
        print(f"Error processing {input_dir}: {e}")
        return {"path": input_dir, "length": 0}


def main():
    args = tyro.cli(Args)
    os.makedirs(args.output_path, exist_ok=True)
    print(f"Output path: {args.output_path}")

    games = [
        os.path.join(args.input_path, d)
        for d in os.listdir(args.input_path)
        if os.path.isdir(os.path.join(args.input_path, d))
    ]
    episodes = [
        os.path.join(game, d)
        for game in games
        for d in os.listdir(game)
        if os.path.isdir(os.path.join(game, d))  # skip stray non-episode files
    ]

    results = []
    num_processes = mp.cpu_count()
    print(f"Number of processes: {num_processes}")
    pool_args = [
        (episode, args.output_path, args.original_fps, args.target_fps, args.target_width)
        for episode in episodes
    ]
    with mp.Pool(processes=num_processes) as pool:
        for result in pool.starmap(preprocess_pngs, pool_args):
            results.append(result)

    print("Done converting png to array_record files")

    # Tally outcomes; failed episodes (length 0) are counted separately
    # from short-but-successful ones (fewer than 1600 frames).
    failed_videos = [result for result in results if result["length"] == 0]
    short_videos = [result for result in results if 0 < result["length"] < 1600]
    num_successful_videos = len(results) - len(failed_videos) - len(short_videos)
    print(f"Number of failed videos: {len(failed_videos)}")
    print(f"Number of short videos: {len(short_videos)}")
    print(f"Number of successful videos: {num_successful_videos}")
    print(f"Number of total videos: {len(results)}")

    metadata = {
        "env": args.env_name,
        "total_videos": len(results),
        "num_successful_videos": num_successful_videos,
        "num_failed_videos": len(failed_videos),
        "num_short_videos": len(short_videos),
        "avg_episode_len": np.mean([ep["length"] for ep in results]),
        "episode_metadata": results,
    }

    with open(os.path.join(args.output_path, "metadata.json"), "w") as f:
        json.dump(metadata, f)

    print("Done.")


if __name__ == "__main__":
    main()
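A typical invocation, assuming tyro's default dashed flag names and an input layout of input_path/<game>/<episode>/<frame>.png; paths and env name are placeholders:

python input_pipeline/preprocess/pngs_to_array_records.py \
    --input-path data/raw_pngs \
    --output-path data/array_records \
    --env-name atari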