import pygame
import pytmx
import math
from random import randint

# Constants
BACKGROUND = (20, 20, 20)  # RGB colour used to clear the screen each frame
SCREEN_WIDTH = 640  # Window size in pixels, passed to set_mode below
SCREEN_HEIGHT = 480
MAP_COLLISION_LAYER = 1  # Not referenced in the visible part of this file

# Pygame setup
# The display must exist before any class below loads an image with
# convert_alpha(), so the window is created here at import time.
pygame.init()
pygame.display.set_caption("Merged Game")
screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
clock = pygame.time.Clock()  # Used by the main loop to cap the frame rate
running = True  # Main-loop flag; cleared when a QUIT event arrives


# Classes
class Player(pygame.sprite.Sprite):
    """Player sprite: WASD movement, mouse shooting, run animation and health bar.

    The instance relies on the module-level ``screen`` (health bar drawing)
    and ``camera_group`` (aim offset and projectile group), which must exist
    before ``update`` is called.
    """

    def __init__(self, pos, group):
        """Load the sprites and place the player centred on ``pos`` in ``group``."""
        super().__init__(group)
        self.idle_image = self.image = pygame.image.load("Sprites/Idle/Player_right.png").convert_alpha()
        self.rect = self.image.get_rect(center=pos)
        self.direction = pygame.math.Vector2()
        self.speed = 2.25
        self.time = 0  # Seconds counter driven by the main loop (health regen)

        # Health
        self.current_health = 1000
        self.maximum_health = 1000
        self.health_bar_length = 500
        # Health points represented by one pixel of the bar.
        self.health_ratio = self.maximum_health / self.health_bar_length
        self.target_health = 1000
        self.health_change_speed = 5

        # Shooting
        self.can_shoot = True
        self.shoot_cooldown_time = 100  # Milliseconds
        self.shoot_cooldown = 0  # Tick count of the most recent shot

        # Animations
        self.image_direction = 1  # 1 = facing right, -1 = drawn mirrored
        self.animation_timer = 0
        self.current_sprite = 0
        self.running_sprites = [
            pygame.image.load(f"Sprites/Running/Player_run_right_{frame}.png").convert_alpha()
            for frame in range(1, 7)
        ]

    def get_damage(self, amount):
        """Lower the target health by ``amount``, never going below zero."""
        if self.target_health > 0:
            self.target_health -= amount
        self.target_health = max(self.target_health, 0)

    def get_health(self, amount):
        """Raise the target health by ``amount``, never exceeding the maximum."""
        if self.target_health < self.maximum_health:
            self.target_health += amount
        self.target_health = min(self.target_health, self.maximum_health)

    def health_bar(self):
        """Ease the displayed health towards the target and draw the bar on ``screen``."""
        # Close a tenth of the remaining gap each call.
        self.current_health += (self.target_health - self.current_health) / 10

        # Default case: health is rising (or steady), so the target leads
        # and the transition strip is green.
        transition_colour = (0, 255, 0)
        health_bar_width = self.current_health / self.health_ratio
        transition_bar_width = self.target_health / self.health_ratio

        if self.current_health > self.target_health:
            # Health is dropping: the displayed value trails behind in yellow.
            health_bar_width = self.target_health / self.health_ratio
            transition_bar_width = self.current_health / self.health_ratio
            transition_colour = (255, 255, 0)

        health_bar_rect = pygame.Rect(10, 45, health_bar_width, 25)
        transition_bar_rect = pygame.Rect(10, 45, transition_bar_width, 25)

        # The wider transition strip goes first so the red bar covers its start.
        pygame.draw.rect(screen, transition_colour, transition_bar_rect)
        pygame.draw.rect(screen, (255, 0, 0), health_bar_rect)
        pygame.draw.rect(screen, (255, 255, 255), (10, 45, self.health_bar_length, 25), 4)

    def input(self):
        """Read keyboard and mouse state; update velocity, facing and fire a shot."""
        keys = pygame.key.get_pressed()

        self.direction.x = keys[pygame.K_d] - keys[pygame.K_a]
        self.direction.y = keys[pygame.K_s] - keys[pygame.K_w]

        # Only a diagonal needs normalising; axis-aligned input is already unit length.
        if self.direction.x != 0 and self.direction.y != 0:
            self.direction.normalize_ip()

        # NOTE(review): update() multiplies by self.speed again, so the
        # effective velocity is speed squared. Kept as-is because the
        # movement feel is presumably tuned around it - confirm before changing.
        self.direction.x *= self.speed
        self.direction.y *= self.speed

        # Face the way the player is moving; keep the last facing when idle.
        if self.direction.x > 0:
            self.image_direction = 1
        elif self.direction.x < 0:
            self.image_direction = -1

        if pygame.mouse.get_pressed()[0] and self.can_shoot:
            # Convert the cursor to world coordinates and aim from the sprite
            # centre, which is also where the projectile is spawned from.
            aim = pygame.math.Vector2(pygame.mouse.get_pos()) + camera_group.offset - self.rect.center
            # A zero-length vector cannot be normalised (ValueError), so a
            # click exactly on the player's centre simply does not fire.
            if aim.length_squared() > 0:
                Projectile(self.rect.center, aim.normalize(), camera_group)
                self.can_shoot = False
                self.shoot_cooldown = pygame.time.get_ticks()

    def update(self):
        """Per-frame step: handle input, draw the health bar, move and animate."""
        self.input()
        self.health_bar()
        self.rect.center += self.direction * self.speed

        now = pygame.time.get_ticks()
        if now - self.shoot_cooldown > self.shoot_cooldown_time:
            self.can_shoot = True

        # Advance the animation at most once every 50 ms.
        if now - self.animation_timer > 50:
            if self.direction.length() == 0:
                self.image = self.idle_image
                self.current_sprite = 0
            else:
                self.animation_timer = now
                self.current_sprite = (self.current_sprite + 1) % len(self.running_sprites)
                self.image = self.running_sprites[self.current_sprite]


class Projectile(pygame.sprite.Sprite):
    """Bullet that flies in a straight line and damages the first enemy it touches.

    Relies on the module-level ``enemies`` list for collision checks.
    """

    # Unrotated bullet image, loaded on the first shot and shared afterwards
    # so that firing does not hit the disk every time.
    _base_image = None

    def __init__(self, pos, direction, group):
        """Spawn a bullet near ``pos`` travelling along ``direction``.

        ``direction`` is used as a Vector2 (its ``x``/``y`` are read and it is
        scaled); the caller in this file passes a normalised vector.
        """
        super().__init__(group)
        if Projectile._base_image is None:
            Projectile._base_image = pygame.image.load("Sprites/bullet.png").convert_alpha()
        self.image_direction = 1

        self.direction = direction
        self.speed = 15
        # Screen y grows downwards, hence the negated y for the rotation angle.
        angle = math.degrees(math.atan2(-self.direction.y, self.direction.x))
        self.image = pygame.transform.rotate(Projectile._base_image, angle)

        # Keep the exact position as floats. The rect only stores integers,
        # so accumulating movement in it would truncate every step and bend
        # the trajectory away from the aimed direction.
        self.pos = pygame.math.Vector2(pos) + self.direction * 50
        self.rect = self.image.get_rect(center=(round(self.pos.x), round(self.pos.y)))

        self.life_timer = 10000  # Milliseconds
        self.spawned_time = pygame.time.get_ticks()

    def collision(self):
        """Damage the first enemy within 20 pixels and remove this bullet."""
        for enemy in enemies:
            distance = (pygame.math.Vector2(enemy.rect.center) - self.rect.center).length()
            if distance < 20:
                enemy.damage(1)
                self.kill()
                # Stop immediately: damage() may remove the enemy from the
                # list that is being iterated.
                break

    def update(self):
        """Advance one step, check for hits and expire after ``life_timer`` ms."""
        self.pos += self.direction * self.speed
        self.rect.center = (round(self.pos.x), round(self.pos.y))
        self.collision()

        now = pygame.time.get_ticks()
        if now - self.spawned_time > self.life_timer:
            self.kill()

class Arm(pygame.sprite.Sprite):
    """Gun sprite that follows the player and rotates towards the mouse cursor.

    Relies on the module-level ``player`` and ``camera_group``.
    """

    def __init__(self, pos, group):
        """Load the gun image scaled to 40x15 and place it centred on ``pos``."""
        super().__init__(group)
        self.base_image = pygame.image.load("Sprites/minigun.png").convert_alpha()
        self.base_image = pygame.transform.scale(self.base_image, (40, 15))

        self.image_direction = 1  # 1 = drawn as-is, -1 = drawn mirrored

        self.direction = pygame.Vector2(0, 0)
        self.speed = 10
        angle = math.degrees(math.atan2(-self.direction.y, self.direction.x))
        self.image = pygame.transform.rotate(self.base_image, angle)
        self.rect = self.image.get_rect(center=pos)

    def update(self):
        """Snap to the player, aim at the cursor and rebuild the rotated image."""
        pivot = player.rect.center
        self.rect.center = pivot

        # Cursor in world coordinates, measured from the pivot (player centre).
        aim = pygame.math.Vector2(pygame.mouse.get_pos()) + camera_group.offset - pivot
        # Normalising a zero-length vector raises ValueError; when the cursor
        # sits exactly on the pivot keep the previous direction instead.
        if aim.length_squared() > 0:
            self.direction = aim.normalize()

        # Screen y grows downwards, hence the negated y.
        angle = math.degrees(math.atan2(-self.direction.y, self.direction.x))
        self.image_direction = 1 if abs(angle) < 90 else -1

        # Shift the gun 20 px away from the pivot along the aim angle.
        self.rect.center -= pygame.Vector2(
            math.sin(math.radians(angle - 90)) * 20,
            math.cos(math.radians(angle - 90)) * 20,
        )

        if abs(angle) > 90:
            # Aiming left: the image is drawn mirrored (image_direction == -1),
            # so rotate by the mirrored angle to keep the gun upright.
            angle = math.degrees(math.atan2(self.direction.y, self.direction.x)) + 180
        self.image = pygame.transform.rotate(self.base_image, angle)


class Enemy(pygame.sprite.Sprite):
    """Red square that chases the player and explodes on contact.

    Relies on the module-level ``player`` and ``enemies`` list.
    """

    def __init__(self, pos, group):
        """Create a 20x20 enemy centred on ``pos`` and add it to ``group``."""
        super().__init__(group)
        self.image = pygame.surface.Surface((20, 20))
        self.image.fill("red")
        self.rect = self.image.get_rect(center=pos)
        # Exact float position; the rect only stores rounded integers.
        self.pos = pygame.math.Vector2(self.rect.center)
        self.direction = pygame.math.Vector2()
        self.speed = 2
        self.health = 2

        self.image_direction = 1

    def _despawn(self):
        """Remove this enemy from the ``enemies`` list and from its sprite groups."""
        if self in enemies:
            enemies.remove(self)
        self.kill()

    def damage(self, damage):
        """Subtract ``damage`` health points and despawn when none are left."""
        self.health -= damage
        if self.health <= 0:
            self._despawn()

    def update(self):
        """Move towards the player; on contact hurt the player and despawn."""
        # Group.update() iterates over a snapshot of its sprites, so an enemy
        # killed earlier in the same pass still gets this call - ignore it.
        if not self.alive():
            return

        self.direction = pygame.math.Vector2(player.rect.center) - self.rect.center
        if self.direction.length() < 20:
            player.get_damage(200)
            player.time = 0  # Restart the player's regeneration timer
            self._despawn()
            # Returning here also avoids normalising a zero-length vector
            # when the enemy sits exactly on the player.
            return

        self.pos += self.direction.normalize() * self.speed
        self.rect.center = (round(self.pos.x), round(self.pos.y))


class CameraGroup(pygame.sprite.Group):
    """Sprite group that draws its members relative to a camera offset."""

    def __init__(self):
        """Bind to the current display surface and cache half of its size."""
        super().__init__()
        self.display_surface = pygame.display.get_surface()

        self.offset = pygame.math.Vector2(300, 100)
        surface_size = self.display_surface.get_size()
        self.half_w = surface_size[0] // 2
        self.half_h = surface_size[1] // 2

    def center_target_camera(self, target):
        """Set the offset so that ``target`` ends up in the middle of the display."""
        centre_x, centre_y = target.rect.centerx, target.rect.centery
        self.offset.x = centre_x - self.half_w
        self.offset.y = centre_y - self.half_h

    def custom_draw(self, player):
        """Centre the camera on ``player`` and blit every sprite, top to bottom.

        Sprites are ordered by their vertical centre, and any sprite whose
        ``image_direction`` is -1 is drawn mirrored horizontally.
        """
        self.center_target_camera(player)

        draw_order = sorted(self.sprites(), key=lambda member: member.rect.centery)
        for member in draw_order:
            frame = member.image
            if member.image_direction == -1:
                frame = pygame.transform.flip(frame, True, False)
            screen_pos = member.rect.topleft - self.offset
            self.display_surface.blit(frame, screen_pos)


class Game:
    """Owns the loaded levels and the overlay image, and draws the backdrop."""

    def __init__(self):
        """Load the overlay image and the first level."""
        self.overlay = pygame.image.load("overlay.png")
        self.currentLevelNumber = 0
        self.levels = [Level(fileName="level1.tmx")]
        self.currentLevel = self.levels[self.currentLevelNumber]

    def draw(self, screen):
        """Clear ``screen`` and draw the current level followed by the overlay.

        The display is deliberately not flipped here: the main loop draws the
        sprites on top afterwards and presents the finished frame once.
        Flipping at this point would show a frame without any sprites.
        """
        screen.fill(BACKGROUND)
        self.currentLevel.draw(screen)
        screen.blit(self.overlay, [0, 0])


class Level:
    """A TMX map loaded through pytmx, held as one ``Layer`` per map layer."""

    def __init__(self, fileName):
        """Load ``fileName`` and build a ``Layer`` for every layer index."""
        self.mapObject = pytmx.load_pygame(fileName)
        layer_count = len(self.mapObject.layers)
        self.layers = [
            Layer(index=layer_index, mapObject=self.mapObject)
            for layer_index in range(layer_count)
        ]

    def draw(self, screen):
        """Draw the layers onto ``screen`` in index order."""
        for map_layer in self.layers:
            map_layer.draw(screen)


class Layer:
    """One map layer, stored as a sprite group of ``Tile`` objects."""

    def __init__(self, index, mapObject):
        """Create a ``Tile`` for every cell of layer ``index`` that has an image."""
        self.index = index
        self.tiles = pygame.sprite.Group()
        self.mapObject = mapObject

        # Walk the grid column by column; cells without an image are skipped.
        for column in range(self.mapObject.width):
            for row in range(self.mapObject.height):
                tile_image = self.mapObject.get_tile_image(column, row, self.index)
                if not tile_image:
                    continue
                pixel_x = column * self.mapObject.tilewidth
                pixel_y = row * self.mapObject.tileheight
                self.tiles.add(Tile(image=tile_image, x=pixel_x, y=pixel_y))

    def draw(self, screen):
        """Blit every tile of this layer onto ``screen``."""
        self.tiles.draw(screen)


class Tile(pygame.sprite.Sprite):
    """Static sprite showing one map tile at a fixed pixel position."""

    def __init__(self, image, x, y):
        """Store ``image`` and position its rect with the top-left corner at (x, y)."""
        super().__init__()
        self.image = image
        self.rect = self.image.get_rect(topleft=(x, y))


def spawn_enemy():
    """Add one enemy at a random offset from the player, up to 20 at a time.

    The offset is drawn independently per axis from [-1000, 1000] pixels.
    Does nothing once the ``enemies`` list already holds 20 entries.
    """
    if len(enemies) >= 20:
        return

    offset_x = randint(-1000, 1000)
    offset_y = randint(-1000, 1000)
    spawn_point = (player.rect.centerx + offset_x, player.rect.centery + offset_y)
    enemies.append(Enemy(spawn_point, camera_group))


# Main game setup
SPAWN_ENEMY_EVENT = pygame.USEREVENT + 1  # Custom event posted by the spawn timer

game = Game()
camera_group = CameraGroup()
player = Player((500, 500), camera_group)
Arm((0, 0), camera_group)
enemies = []  # Live enemies; filled by spawn_enemy()
pygame.time.set_timer(SPAWN_ENEMY_EVENT, 100)  # Post a spawn event every 100 ms


# Main game loop
while running:
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            running = False
        elif event.type == SPAWN_ENEMY_EVENT:
            # Without this branch the timer fires but no enemy ever appears.
            spawn_enemy()
        elif event.type == pygame.KEYDOWN:
            # Up/Down arrows add or remove 200 health.
            if event.key == pygame.K_UP:
                player.get_health(200)
            elif event.key == pygame.K_DOWN:
                player.get_damage(200)

    # Regenerate health once the player has gone more than 3 s without a hit
    # (an enemy contact resets player.time to 0).
    if player.time > 3 and player.current_health < 999:
        player.get_health(200)
        player.time -= 1.5

    # game.draw() clears the screen itself, so no separate fill is needed.
    game.draw(screen)
    camera_group.update()
    camera_group.custom_draw(player)
    pygame.display.flip()

    # clock.tick() caps the loop at 60 FPS and returns the elapsed
    # milliseconds, which keeps player.time in real seconds even when
    # frames run slow.
    player.time += clock.tick(60) / 1000

pygame.quit()