Skip to content

Commit 56ae4cf

Browse files
authored
Merge pull request #65 from jeffin07/master
Simple Linear Regression
2 parents 27f73c0 + 913add4 commit 56ae4cf

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import numpy as np
2+
import matplotlib.pyplot as plt
3+
class SimpleLinearRegression:
4+
def __init__(self):
5+
self.w_0 = 0
6+
self.w_1 = 0
7+
8+
def fit(self,x,y):
9+
n = len(x)
10+
m_x,m_y = np.mean(x),np.mean(y)
11+
self.w_1 = np.sum(y * x - n * m_x * m_y) / np.sum(x * x - n * m_x * m_x)
12+
self.w_0 = m_y - (self.w_1 * m_x)
13+
return(self.w_0,self.w_1)
14+
15+
def predict(self,x_test):
16+
y_pred = self.w_0 + self.w_1 * x_test
17+
return y_pred
18+
19+
def plot(self,x,y):
20+
plt.scatter(x,y,color="r",marker="o")
21+
line=self.w_0 + self.w_1 * x
22+
plt.plot(x,line,color="g")
23+
plt.xlabel("X")
24+
plt.ylabel("Y")
25+
plt.show()
26+
27+
if __name__ == "__main__":
28+
x=np.array([1,2,3,4,5,6,7,8,9,10])
29+
y=np.array([300,400,500,600,700,800,900,1000,1200,1400])
30+
model=SimpleLinearRegression()
31+
model.fit(x,y)
32+
print(model.predict(11))
33+
model.plot(x,y)

0 commit comments

Comments
 (0)