Skip to content

Commit 4d076cd

Browse files
authored
Merge pull request #623 from AswinJ1/patch-1
Create RL_game_env.py
2 parents f339191 + adab23f commit 4d076cd

File tree

1 file changed

+157
-0
lines changed

1 file changed

+157
-0
lines changed

python/RL_game_env.py

+157
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
#reset
2+
#reward
3+
#play(action) ->action
4+
import pygame
5+
import random
6+
from enum import Enum
7+
from collections import namedtuple
8+
import numpy as np
9+
10+
pygame.init()
11+
#ont = pygame.font.Font('arial.ttf', 25)
12+
font = pygame.font.SysFont('arial', 25)
13+
14+
class Direction(Enum):
15+
RIGHT = 1
16+
LEFT = 2
17+
UP = 3
18+
DOWN = 4
19+
20+
Point = namedtuple('Point', 'x, y')
21+
22+
# rgb colors
23+
WHITE = (255, 255, 255)
24+
RED = (200,0,0)
25+
BLUE1 = (0, 0, 255)
26+
BLUE2 = (0, 100, 255)
27+
BLACK = (0,0,0)
28+
29+
BLOCK_SIZE = 20
30+
SPEED = 40
31+
32+
class SnakeGameAI:
33+
34+
def __init__(self, w=640, h=480):
35+
self.w = w
36+
self.h = h
37+
# init display
38+
self.display = pygame.display.set_mode((self.w, self.h))
39+
pygame.display.set_caption('Snake')
40+
self.clock = pygame.time.Clock()
41+
self.reset()
42+
43+
44+
def reset(self):
45+
# init game state
46+
self.direction = Direction.RIGHT
47+
48+
self.head = Point(self.w/2, self.h/2)
49+
self.snake = [self.head,
50+
Point(self.head.x-BLOCK_SIZE, self.head.y),
51+
Point(self.head.x-(2*BLOCK_SIZE), self.head.y)]
52+
53+
self.score = 0
54+
self.food = None
55+
self._place_food()
56+
self.frame_iteration = 0
57+
58+
59+
def _place_food(self):
60+
x = random.randint(0, (self.w-BLOCK_SIZE )//BLOCK_SIZE )*BLOCK_SIZE
61+
y = random.randint(0, (self.h-BLOCK_SIZE )//BLOCK_SIZE )*BLOCK_SIZE
62+
self.food = Point(x, y)
63+
if self.food in self.snake:
64+
self._place_food()
65+
66+
67+
def play_step(self, action):
68+
self.frame_iteration += 1
69+
# 1. collect user input
70+
for event in pygame.event.get():
71+
if event.type == pygame.QUIT:
72+
pygame.quit()
73+
quit()
74+
75+
# 2. move
76+
self._move(action) # update the head
77+
self.snake.insert(0, self.head)
78+
79+
# 3. check if game over
80+
reward = 0
81+
game_over = False
82+
if self.is_collision() or self.frame_iteration > 100*len(self.snake):
83+
game_over = True
84+
reward = -10
85+
return reward, game_over, self.score
86+
87+
# 4. place new food or just move
88+
if self.head == self.food:
89+
self.score += 1
90+
reward = 10
91+
self._place_food()
92+
else:
93+
self.snake.pop()
94+
95+
# 5. update ui and clock
96+
self._update_ui()
97+
self.clock.tick(SPEED)
98+
# 6. return game over and score
99+
return reward, game_over, self.score
100+
101+
102+
def is_collision(self, pt=None):
103+
if pt is None:
104+
pt = self.head
105+
# hits boundary
106+
if pt.x > self.w - BLOCK_SIZE or pt.x < 0 or pt.y > self.h - BLOCK_SIZE or pt.y < 0:
107+
return True
108+
# hits itself
109+
if pt in self.snake[1:]:
110+
return True
111+
112+
return False
113+
114+
115+
def _update_ui(self):
116+
self.display.fill(BLACK)
117+
118+
for pt in self.snake:
119+
pygame.draw.rect(self.display, BLUE1, pygame.Rect(pt.x, pt.y, BLOCK_SIZE, BLOCK_SIZE))
120+
pygame.draw.rect(self.display, BLUE2, pygame.Rect(pt.x+4, pt.y+4, 12, 12))
121+
122+
pygame.draw.rect(self.display, RED, pygame.Rect(self.food.x, self.food.y, BLOCK_SIZE, BLOCK_SIZE))
123+
124+
text = font.render("Score: " + str(self.score), True, WHITE)
125+
self.display.blit(text, [0, 0])
126+
pygame.display.flip()
127+
128+
129+
def _move(self, action):
130+
# [straight, right, left]
131+
132+
clock_wise = [Direction.RIGHT, Direction.DOWN, Direction.LEFT, Direction.UP]
133+
idx = clock_wise.index(self.direction)
134+
135+
if np.array_equal(action, [1, 0, 0]):
136+
new_dir = clock_wise[idx] # no change
137+
elif np.array_equal(action, [0, 1, 0]):
138+
next_idx = (idx + 1) % 4
139+
new_dir = clock_wise[next_idx] # right turn r -> d -> l -> u
140+
else: # [0, 0, 1]
141+
next_idx = (idx - 1) % 4
142+
new_dir = clock_wise[next_idx] # left turn r -> u -> l -> d
143+
144+
self.direction = new_dir
145+
146+
x = self.head.x
147+
y = self.head.y
148+
if self.direction == Direction.RIGHT:
149+
x += BLOCK_SIZE
150+
elif self.direction == Direction.LEFT:
151+
x -= BLOCK_SIZE
152+
elif self.direction == Direction.DOWN:
153+
y += BLOCK_SIZE
154+
elif self.direction == Direction.UP:
155+
y -= BLOCK_SIZE
156+
157+
self.head = Point(x, y)

0 commit comments

Comments
 (0)