I'm trying to implement NEAT algorithm to a snake game i coded. However, it doesn't work, and the snakes just go around non-stop. The last thing i try, was to implement some sort of "danger or food check", to tell the snake if the food was near or not. It looks in 3 different ways by 8 blocks.
import pygame as pg
import os, neat, random
SCREEN_HEIGHT = 650
SCREEN_WIDTH = SCREEN_HEIGHT - 50
WIDTH = SCREEN_WIDTH
HEIGHT = SCREEN_HEIGHT - 50
BLOCKSIZE = 30
COLS = WIDTH // BLOCKSIZE
ROWS = HEIGHT // BLOCKSIZE
class Snake:
def __init__(self, screen, color):
self.screen = screen
self.location =
pg.math.Vector2(COLS/2,ROWS/2)
self.moving_to = pg.math.Vector2(1,0)
self.length = 1
self.positions = [self.location]
self.alive = True
self.hungry = 0
self.color = color
self.food = Food(screen, color)
self.distanceHistory = []
def reset(self):
self.location = pg.math.Vector2(COLS/2,ROWS/2)
self.length = 1
self.alive = True
self.positions = [self.location]
def render(self):
for position in self.positions:
pg.draw.rect(self.screen, self.color, (position[0] * BLOCKSIZE, position[1] * BLOCKSIZE , BLOCKSIZE, BLOCKSIZE))
self.food.render()
def move(self):
new_position = self.location + self.moving_to
if self.alive:
self.location += self.moving_to
#Move and insert new position to keep track of the tail
self.positions.insert(0, new_position)
if len(self.positions) > self.length:
#Pop if not necessary the oldest position saved
self.positions.pop()
def threeWaystoDie(self):
self.hungry += 1
new_position = self.location + self.moving_to
#If touches it's tail
if len(self.positions) and new_position in self.positions[2:]:
self.alive = False
#If touches a wall
elif not (0 <= self.location[0] <= COLS - 1 and 0 <= self.location[1] <= ROWS - 1):
self.alive = False
if self.hungry >= 200:
self.alive = False
def checkDangerOrFood(self, blockstoCheck):
straight = []
right = []
left = []
#Each moving to, has 3 possible directions to check
#If wall or tail near
#RIGHT
#straight
if (self.moving_to[0], self.moving_to[1]) == (1, 0):
for i in range(blockstoCheck):
new_position = self.location + self.moving_to * i
if new_position[0] >= COLS - 1 or new_position in self.positions[2:]:
straight.append(-1)
elif new_position == self.food.location:
straight.append(1)
else:
straight.append(0)
#up
for i in range(blockstoCheck):
new_position = self.location + pg.math.Vector2(0, -1) * i
if new_position[1] <= 0 or new_position in self.positions[2:]:
left.append(-1)
elif new_position == self.food.location:
left.append(1)
else:
left.append(0)
#down
for i in range(blockstoCheck):
new_position = self.location + pg.math.Vector2(0, 1) * i
if new_position[1] <= 0 or new_position in self.positions[2:]:
right.append(-1)
elif new_position == self.food.location:
right.append(1)
else:
right.append(0)
#LEFT
#straight
if (self.moving_to[0], self.moving_to[1]) == (-1, 0):
for i in range(blockstoCheck):
new_position = self.location + self.moving_to * i
if new_position[0] <= 0 or new_position in self.positions[2:]:
straight.append(-1)
elif new_position == self.food.location:
straight.append(1)
else:
straight.append(0)
#down
for i in range(blockstoCheck):
new_position = self.location + pg.math.Vector2(0, 1) * i
if new_position[1] >= ROWS - 1 or new_position in self.positions[2:]:
left.append(-1)
elif new_position == self.food.location:
left.append(1)
else:
left.append(0)
#up
for i in range(blockstoCheck):
new_position = self.location + pg.math.Vector2(0, -1) * i
if new_position[1] <= 0 or new_position in self.positions[2:]:
right.append(-1)
elif new_position == self.food.location:
right.append(1)
else:
right.append(0)
#UP
#straight
if (self.moving_to[0], self.moving_to[1]) == (0, -1):
for i in range(blockstoCheck):
new_position = self.location + self.moving_to * i
if new_position[1] <= 0 or new_position in self.positions[2:]:
straight.append(-1)
elif new_position == self.food.location:
straight.append(1)
else:
straight.append(0)
#left
for i in range(blockstoCheck):
new_position = self.location + pg.math.Vector2(-1, 0) * i
if new_position[0] <= 0 or new_position in self.positions[2:]:
left.append(-1)
elif new_position == self.food.location:
left.append(1)
else:
left.append(0)
#right
for i in range(blockstoCheck):
new_position = self.location + pg.math.Vector2(1, 0) * i
if new_position [0] >= COLS - 1 or new_position in self.positions[2:]:
right.append(-1)
elif new_position == self.food.location:
right.append(1)
else:
right.append(0)
#DOWN
#straight
if (self.moving_to[0], self.moving_to[1]) == (0, 1):
for i in range(blockstoCheck):
new_position = self.location + self.moving_to * i
if new_position[1] >= ROWS - 1 or new_position in self.positions[2:]:
straight.append(-1)
elif new_position == self.food.location:
straight.append(1)
else:
straight.append(0)
#left
for i in range(blockstoCheck):
new_position = self.location + pg.math.Vector2(0, 1) * i
if new_position[0] <= 0 or new_position in self.positions[2:]:
left.append(-1)
elif new_position == self.food.location:
left.append(1)
else:
left.append(0)
#right
for i in range(blockstoCheck):
new_position = self.location + pg.math.Vector2(0, -1) * i
if new_position[0] >= COLS - 1 or new_position in self.positions[2:]:
right.append(-1)
elif new_position == self.food.location:
right.append(1)
else:
right.append(0)
return straight, left, right
class Food:
def __init__(self, screen, color):
self.screen = screen
self.location = pg.math.Vector2(randint(0, COLS - 1), randint(0, ROWS - 1))
self.color = color
def render(self):
pg.draw.circle(self.screen, self.color, (self.location[0] * BLOCKSIZE + (BLOCKSIZE//2), self.location[1] * BLOCKSIZE+ (BLOCKSIZE//2)), 10)
def reset(self, positionsOfSnake):
self.location = pg.math.Vector2(randint(0, COLS - 1), randint(0, ROWS - 1))
#Not to spawn on the snake's tail
while self.location in positionsOfSnake:
self.location = pg.math.Vector2(randint(0, COLS - 1), randint(0, ROWS - 1))
GEN = 0
colors = ["red", "yellow", "purple", "white", "green", "blue"]
def draw_stuff(screen, gen, snakes):
font = pg.font.SysFont("timesnewroman", 20, bold=True)
label = font.render(f"Gen: {gen} Snakes Alive: {snakes}", 1, "red")
screen.blit(label, (WIDTH / 2 - (label.get_width()/2),
(SCREEN_HEIGHT - 30) - label.get_height()/2))
def draw_lines(screen):
# Lines
for x in range(COLS + 1):
pg.draw.line(screen, "red", (x * BLOCKSIZE, 0),
(x * BLOCKSIZE, HEIGHT))
for x in range(ROWS + 1):
pg.draw.line(screen, "red", (0, x * BLOCKSIZE), (WIDTH, x * BLOCKSIZE))
def draw_grid(screen):
# Contour
pg.draw.rect(screen, "red", (0, 0, WIDTH, HEIGHT), 2)
#draw_lines(screen)
def main(genomes, config):
global HUMAN, GEN
screen = pg.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
pg.init()
clock = pg.time.Clock()
snakes = []
nets = []
ge = []
for _, g in genomes:
net = neat.nn.FeedForwardNetwork.create(g, config)
nets.append(net)
color = random.choice(colors)
snakes.append(Snake(screen, color))
g.fitness = 0
ge.append(g)
score = 0
run = True
while run:
pg.display.update()
clock.tick(10)
screen.fill((30, 30, 30))
for event in pg.event.get():
if event.type == pg.QUIT:
run = False
pg.quit()
quit()
draw_grid(screen)
draw_stuff(screen, GEN, len(snakes))
for x, snake in enumerate(snakes):
#ge[x].fitness += 0.1
snakes[x].move()
snakes[x].render()
snakes[x].threeWaystoDie()
distance2Food = snakes[x].location.distance_to(snakes[x].food.location)
blockstoCheck = 8
dangers = snakes[x].checkDangerOrFood(blockstoCheck)
straight = dangers[0]
left = dangers[1]
right = dangers[2]
inputs = (
#snakes[x].location[0],
#snakes[x].location[1],
snakes[x].moving_to[0],
snakes[x].moving_to[1],
# snakes[x].hungry,
distance2Food,
snakes[x].location.angle_to(snakes[x].food.location)
#snakes[x].food.location[0],
#snakes[x].food.location[1],
)
#Adding dangers to inputs (3x8)
for i in range(blockstoCheck):
inputs = inputs + (straight[i], left[i], right[i])
output = nets[x].activate((inputs))
indexOfMaxValue = output.index(max(output))
if (snake.moving_to[0], snake.moving_to[1]) == (1, 0):
if indexOfMaxValue == 0:
turn = snakes[x].moving_to
if indexOfMaxValue == 1:
turn = pg.math.Vector2(0, -1)
if indexOfMaxValue == 2:
turn = pg.math.Vector2(0, 1)
if (snake.moving_to[0], snake.moving_to[1]) == (-1, 0):
if indexOfMaxValue == 0:
turn = snakes[x].moving_to
if indexOfMaxValue == 1:
turn = pg.math.Vector2(0, 1)
if indexOfMaxValue == 2:
turn = pg.math.Vector2(0, -1)
if (snake.moving_to[0], snake.moving_to[1]) == (0, 1):
if indexOfMaxValue == 0:
turn = snakes[x].moving_to
if indexOfMaxValue == 1:
turn = pg.math.Vector2(1, 0)
if indexOfMaxValue == 2:
turn = pg.math.Vector2(-1, 0)
if (snake.moving_to[0], snake.moving_to[1]) == (0, -1):
if indexOfMaxValue == 0:
turn = snakes[x].moving_to
if indexOfMaxValue == 1:
turn = pg.math.Vector2(-1, 0)
if indexOfMaxValue == 2:
turn = pg.math.Vector2(1, 0)
snakes[x].moving_to = turn
if len(snakes) <= 0:
run = False
GEN += 1
break
for x, snake in enumerate(snakes):
if snakes[x].location.distance_to(snakes[x].food.location) <= 1.0:
snake.length += 1
score += 10
snakes[x].food.reset(snakes[x].positions)
snakes[x].hungry = -snake.length
ge[x].fitness += 10
#IF snake dies
if snakes[x].alive == False:
ge[x].fitness -= 1
snakes.pop(x)
nets.pop(x)
ge.pop(x)
def run_neat(config):
#pop = neat.Checkpointer.restore_checkpoint('neat-checkpoint-27')
pop = neat.Population(config)
pop.add_reporter(neat.StdOutReporter(True))
stats = neat.StatisticsReporter()
pop.add_reporter(stats)
pop.add_reporter(neat.Checkpointer(10))
winner = pop.run(main, 50)
if __name__ == "__main__":
pg.display.set_caption('Snake')
pg.font.init()
local_dir = os.path.dirname(__file__)
config_path = os.path.join(local_dir, "config.txt")
config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction,
neat.DefaultSpeciesSet, neat.DefaultStagnation, config_path)
run_neat(config)
Here's also the config file for the neat algorithm
[NEAT]
fitness_criterion = mean
fitness_threshold = 400
pop_size = 250
reset_on_extinction = False
[DefaultStagnation]
species_fitness_func = max
max_stagnation = 20
species_elitism = 2
[DefaultReproduction]
elitism = 2
survival_threshold = 0.2
[DefaultGenome]
# node activation options
activation_default = tanh
activation_mutate_rate = 1.0
activation_options = tanh
# node aggregation options
aggregation_default = sum
aggregation_mutate_rate = 0.0
aggregation_options = sum
# node bias options
bias_init_mean = 3.0
bias_init_stdev = 1.0
bias_max_value = 30.0
bias_min_value = -30.0
bias_mutate_power = 0.5
bias_mutate_rate = 0.7
bias_replace_rate = 0.1
# genome compatibility options
compatibility_disjoint_coefficient = 1.0
compatibility_weight_coefficient = 0.5
# connection add/remove rates
conn_add_prob = 0.5
conn_delete_prob = 0.5
# connection enable options
enabled_default = True
enabled_mutate_rate = 0.04
feed_forward = True
initial_connection = full_direct
# node add/remove rates
node_add_prob = 0.2
node_delete_prob = 0.2
# network parameters
num_hidden = 3
num_inputs = 28
num_outputs = 3
# node response options
response_init_mean = 1.0
response_init_stdev = 0.0
response_max_value = 30.0
response_min_value = -30.0
response_mutate_power = 0.0
response_mutate_rate = 0.0
response_replace_rate = 0.0
# connection weight options
weight_init_mean = 0.0
weight_init_stdev = 1.0
weight_max_value = 30
weight_min_value = -30
weight_mutate_power = 0.5
weight_mutate_rate = 0.8
weight_replace_rate = 0.1
[DefaultSpeciesSet]
compatibility_threshold = 3.0
I would appreciate any advice u can give me, since i'm stuck as hell.
I've tried with different inputs and outputs, but nothing seems to work, which makes me think that there's an error somewhere else. Here's some of the outputs examples:
[0.9999999999999445, -0.999999999999569, 1.0]
[-0.9999999999497647, 1.0, 1.0]
[-1.0, 1.0, -1.0]
[-1.0, 1.0, 1.0]
[-1.0, 0.8234808233113396, 1.0]
[-0.9999999999991865, -0.9999999999999998, 1.0]