Skip to content

Commit 8dfdabf

Browse files
working streeteasy dynaTorch example test
1 parent 2493821 commit 8dfdabf

File tree

1 file changed

+61
-0
lines changed

1 file changed

+61
-0
lines changed

tests/test_dynatorch.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,64 @@ def test_titanic_classification_torch():
232232

233233
accuracy = accuracy_score(y_val_torch, pred_classes)
234234
print(f"DynaTorchModel Accuracy: {accuracy:.4f}")
235+
236+
237+
238+
# DynaTorch Example that doesn't suck
239+
from torch.optim import Adam
240+
241+
def test_streeteasy_regression_dynatorch():
242+
apartments_df = pd.read_csv("tests/street_easy_data/streeteasy.csv")
243+
244+
numerical_features = [
245+
"bedrooms", "bathrooms", "size_sqft", "min_to_subway",
246+
"floor", "building_age_yrs", "no_fee", "has_roofdeck",
247+
"has_washer_dryer", "has_doorman", "has_elevator",
248+
"has_dishwasher", "has_patio", "has_gym"
249+
]
250+
X = apartments_df[numerical_features].to_numpy()
251+
y = apartments_df["rent"].to_numpy()
252+
253+
X = torch.tensor(X, dtype=torch.float32)
254+
y = torch.tensor(y, dtype=torch.float32).view(-1, 1)
255+
256+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
257+
258+
layer_dims = [X_train.shape[1], 128, 64, 1]
259+
model = DynaTorchModel(layer_dims, task_type="regression")
260+
261+
loss_fn = model.loss_fn # MSELoss defined in DynaTorchModel
262+
optimizer = Adam(model.parameters(), lr=0.001)
263+
264+
num_epochs = 20000
265+
for epoch in range(num_epochs):
266+
model.train()
267+
optimizer.zero_grad()
268+
predictions = model(X_train)
269+
loss = loss_fn(predictions, y_train)
270+
loss.backward()
271+
optimizer.step()
272+
273+
# Logging progress
274+
if (epoch + 1) % 1000 == 0:
275+
print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")
276+
277+
model.eval()
278+
with torch.no_grad():
279+
test_predictions = model(X_test)
280+
test_loss = loss_fn(test_predictions, y_test)
281+
rmse = torch.sqrt(test_loss)
282+
283+
print(f"Test RMSE: {rmse.item():.4f}")
284+
285+
plt.figure(figsize=(10, 6))
286+
plt.scatter(y_test.numpy(), test_predictions.numpy(), alpha=0.5, color="blue", label="Predictions")
287+
288+
min_val, max_val = y_test.min().item(), y_test.max().item()
289+
plt.plot([min_val, max_val], [min_val, max_val], linestyle="--", color="pink", label="y = x")
290+
291+
plt.xlabel("Actual Rent")
292+
plt.ylabel("Predicted Rent")
293+
plt.title(f"Predicted vs Actual Rent (RMSE: {rmse.item():.2f})")
294+
plt.legend()
295+
plt.show()

0 commit comments

Comments
 (0)