Skip to content

Commit 88e01aa

Browse files
committed
Runtime 1933 ms (Top 34.92%) | Memory 47.0 MB (Top 66.67%)
1 parent 46b3674 commit 88e01aa

File tree

1 file changed

+37
-43
lines changed

1 file changed

+37
-43
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,39 @@
1-
// Runtime: 2256 ms (Top 53.49%) | Memory: 63.7 MB (Top 23.26%)
2-
from functools import lru_cache
31
class Solution:
42
def getMaxGridHappiness(self, m: int, n: int, introvertsCount: int, extrovertsCount: int) -> int:
5-
def calc(x, y): #calculate interaction
6-
if x == 0 or y == 0:
7-
return 0
8-
if x == 1 and y == 1:
9-
return -60
10-
if x == 2 and y == 2:
11-
return 40
12-
return -10
13-
n3, highest = 3 ** n, 3 ** (n - 1)
14-
mask_all, rolling = {}, {}
15-
for mask in range(n3): #get every state in ternary expression
16-
tmp = mask
17-
ternary = []
18-
for _ in range(n):
19-
ternary.append(tmp % 3)
20-
tmp //= 3
21-
mask_all[mask] = ternary[::-1]
22-
rolling[mask] = [ #get the three states after a slide
23-
mask % highest * 3 + 0,
24-
mask % highest * 3 + 1,
25-
mask % highest * 3 + 2
26-
]
27-
@lru_cache(None)
28-
def dfs(pos, borderline, i, e):
29-
if pos == m * n or not i and not e: #exit of dfs
30-
return 0
31-
x, y = divmod(pos, n)
32-
#put 0 in pos
33-
best = dfs(pos + 1, rolling[borderline][0], i, e)
34-
#put 1 in pos
35-
if i > 0:
36-
best = max(best, 120 + calc(1, mask_all[borderline][0]) +
37-
(0 if y == 0 else calc(1, mask_all[borderline][n - 1])) +
38-
dfs(pos + 1, rolling[borderline][1], i - 1, e))
39-
#put 2 in pos
40-
if e > 0:
41-
best = max(best, 40 + calc(2, mask_all[borderline][0]) +
42-
(0 if y == 0 else calc(2, mask_all[borderline][n - 1])) +
43-
dfs(pos + 1, rolling[borderline][2], i, e - 1))
44-
return best
45-
return dfs(0, 0, introvertsCount, extrovertsCount)
3+
4+
@cache
5+
def fn(prev, i, j, intro, extro):
6+
"""Return max grid happiness at (i, j)."""
7+
if i == m: return 0 # no more position
8+
if j == n: return fn(prev, i+1, 0, intro, extro)
9+
if intro == extro == 0: return 0
10+
11+
prev0 = prev[:j] + (0,) + prev[j+1:]
12+
ans = fn(prev0, i, j+1, intro, extro)
13+
if intro:
14+
val = 120
15+
if i and prev[j]: # neighbor from above
16+
val -= 30
17+
if prev[j] == 1: val -= 30
18+
else: val += 20
19+
if j and prev[j-1]: # neighbor from left
20+
val -= 30
21+
if prev[j-1] == 1: val -= 30
22+
else: val += 20
23+
prev0 = prev[:j] + (1,) + prev[j+1:]
24+
ans = max(ans, val + fn(prev0, i, j+1, intro-1, extro))
25+
if extro:
26+
val = 40
27+
if i and prev[j]:
28+
val += 20
29+
if prev[j] == 1: val -= 30
30+
else: val += 20
31+
if j and prev[j-1]:
32+
val += 20
33+
if prev[j-1] == 1: val -= 30
34+
else: val += 20
35+
prev0 = prev[:j] + (2,) + prev[j+1:]
36+
ans = max(ans, val + fn(prev0, i, j+1, intro, extro-1))
37+
return ans
38+
39+
return fn((0,)*n, 0, 0, introvertsCount, extrovertsCount)

0 commit comments

Comments
 (0)