style: 🎨 re-format
xihuai18 committed Jun 24, 2024
1 parent 378c121 commit 4b55300
Showing 163 changed files with 288 additions and 463 deletions.
4 changes: 2 additions & 2 deletions assets/policy_config.example
@@ -96,11 +96,11 @@
'use_value_active_masks': True,
'use_valuenorm': True,
'use_wandb': False,
-'user_name': 'gaojiaxuan',
+'user_name': 'your wandb name',
'value_loss_coef': 1,
'w0': '1,1,1,1',
'w1': '1,1,1,1',
-'wandb_name': 'samji2000',
+'wandb_name': 'your wandb name',
'wandb_tags': [],
'weight_decay': 0},
Box(0.0, inf, (9, 5, 20), float32),
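Both edits swap personal Weights & Biases identifiers for placeholders. For context, a minimal sketch of how such config fields typically feed `wandb.init` — the exact keyword mapping is an assumption, not taken from this repository:

```python
import wandb

# Assumed mapping: 'user_name' -> entity, 'wandb_name' -> run name,
# 'wandb_tags' -> tags. Replace the placeholders with your own account.
run = wandb.init(
    project="zsceval",           # hypothetical project name
    entity="your wandb name",    # from 'user_name'
    name="your wandb name",      # from 'wandb_name'
    tags=[],                     # from 'wandb_tags'
)
run.finish()
```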
8 changes: 5 additions & 3 deletions install_grf.sh
@@ -5,9 +5,11 @@ sudo apt-get install git cmake build-essential libgl1-mesa-dev libsdl2-dev \
libsdl2-image-dev libsdl2-ttf-dev libsdl2-gfx-dev libboost-all-dev \
libdirectfb-dev libst-dev mesa-utils xvfb x11vnc python3-pip -y

+sudo apt reinstall libffi7

## build
-pip install --user wheel==0.38.0 setuptools==65.5.0 six
-conda install anaconda::py-boost -y
+pip install wheel setuptools six
+# conda install anaconda::py-boost -y

### dependences
# cd /usr/lib/x86_64-linux-gnu/
@@ -32,4 +34,4 @@ pip install gfootball

### test
export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libffi.so.7
-python3 -c "import gfootball.env as football_env; env = football_env.create_environment('academy_3_vs_1_with_keeper'); print(env.reset()); print(env.step([0]))"
+python -c "import gfootball.env as football_env; env = football_env.create_environment('academy_3_vs_1_with_keeper'); print(env.reset()); print(env.step([0]))"
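The script reinstalls libffi7 and preloads `libffi.so.7` before the import test above. A quick standalone check that the library actually resolves, using only the standard library (the path mirrors the `LD_PRELOAD` line):

```python
import ctypes

# Raises OSError if libffi.so.7 is missing or cannot be loaded,
# which is the failure the LD_PRELOAD workaround papers over.
libffi = ctypes.CDLL("/usr/lib/x86_64-linux-gnu/libffi.so.7")
print("libffi.so.7 loaded:", libffi._name)
```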
2 changes: 1 addition & 1 deletion setup.py
@@ -4,7 +4,7 @@
import os

import setuptools
-from setuptools import find_packages, setup
+from setuptools import setup


def get_version() -> str:
1 change: 0 additions & 1 deletion zsceval/algorithms/population/cole.py
@@ -3,7 +3,6 @@
from collections import OrderedDict

import torch
-from loguru import logger

from zsceval.algorithms.population.policy_pool import PolicyPool
from zsceval.algorithms.population.trainer_pool import TrainerPool
2 changes: 0 additions & 2 deletions zsceval/algorithms/population/policy_pool.py
@@ -1,13 +1,11 @@
import os
import pickle
import warnings
from pprint import pformat
from typing import Dict, List, Tuple

import numpy as np
import torch
import yaml
from loguru import logger

from zsceval.algorithms.population.utils import EvalPolicy
from zsceval.runner.shared.base_runner import make_trainer_policy_cls
4 changes: 2 additions & 2 deletions zsceval/algorithms/population/trainer_pool.py
@@ -2,7 +2,7 @@
import os
from collections import OrderedDict, defaultdict
from copy import deepcopy
-from typing import Any, Dict, Tuple
+from typing import Any, Dict

import numpy as np
import torch
@@ -232,7 +232,7 @@ def insert_data(
if self.skip(trainer_name):
continue

-trainer = self.trainer_pool[trainer_name]
+self.trainer_pool[trainer_name]
buffer = self.buffer_pool[trainer_name]

(
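The hunk above drops the unused `trainer` binding but keeps the bare subscript, so an unknown `trainer_name` still fails fast with `KeyError`. A standalone illustration of that behavior (plain dict, not the actual `TrainerPool`):

```python
trainer_pool = {"mappo_0": object()}

trainer_pool["mappo_0"]      # value discarded, but the key is validated
try:
    trainer_pool["missing"]  # a bare subscript still raises on absent keys
except KeyError as err:
    print("unknown trainer:", err)
```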
9 changes: 0 additions & 9 deletions zsceval/algorithms/population/traj.py
@@ -1,22 +1,13 @@
import copy
import itertools
import logging
import os
import random
from collections import defaultdict
from typing import Dict, List

import numpy as np
import torch
from loguru import logger

from zsceval.algorithms.population.policy_pool import PolicyPool
from zsceval.algorithms.population.trainer_pool import TrainerPool
from zsceval.algorithms.population.utils import _t2n
from zsceval.algorithms.r_mappo.r_mappo import R_MAPPO
from zsceval.runner.shared.base_runner import make_trainer_policy_cls
from zsceval.utils.shared_buffer import SharedReplayBuffer
from zsceval.utils.util import get_shape_from_obs_space


class Traj_Trainer(TrainerPool):
3 changes: 1 addition & 2 deletions zsceval/algorithms/population/utils.py
@@ -1,6 +1,5 @@
import numpy as np
import torch
-from loguru import logger


def _t2n(x):
@@ -68,4 +67,4 @@ def to(self, device):
self.policy.to(device)

def prep_rollout(self):
-self.policy.prep_rollout()
\ No newline at end of file
+self.policy.prep_rollout()
2 changes: 0 additions & 2 deletions zsceval/algorithms/r_mappo/algorithm/rMAPPOPolicy.py
@@ -1,6 +1,4 @@
from collections import OrderedDict

import numpy as np
import torch
from loguru import logger

5 changes: 0 additions & 5 deletions zsceval/algorithms/r_mappo/algorithm/rMAPPOPolicy_epsilon.py
@@ -1,12 +1,7 @@
from collections import OrderedDict

import numpy as np
import torch
from loguru import logger

from zsceval.algorithms.r_mappo.algorithm.r_actor_critic import R_Actor, R_Critic
from zsceval.algorithms.r_mappo.algorithm.rMAPPOPolicy import R_MAPPOPolicy
from zsceval.utils.util import update_linear_schedule


class R_MAPPOPolicy_Epsilon(R_MAPPOPolicy):
2 changes: 0 additions & 2 deletions zsceval/algorithms/r_mappo/algorithm/r_actor_critic.py
@@ -1,9 +1,7 @@
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger

from zsceval.algorithms.utils.act import ACTLayer
9 changes: 2 additions & 7 deletions zsceval/algorithms/r_mappo/r_mappo.py
@@ -1,13 +1,8 @@
import math
import time
from collections import defaultdict
from pprint import pformat

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from loguru import logger

from zsceval.algorithms.utils.util import check
@@ -264,9 +259,9 @@ def ppo_update(

def update_actor(self):
if self._use_max_grad_norm:
-actor_grad_norm = nn.utils.clip_grad_norm_(self.policy.actor.parameters(), self.max_grad_norm)
+nn.utils.clip_grad_norm_(self.policy.actor.parameters(), self.max_grad_norm)
else:
-actor_grad_norm = get_gard_norm(self.policy.actor.parameters())
+get_gard_norm(self.policy.actor.parameters())

self.policy.actor_optimizer.step()

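The `update_actor` hunk discards the unused `actor_grad_norm` value while keeping the call itself: `clip_grad_norm_` rescales gradients in place, which is the effect needed before the optimizer step. A self-contained PyTorch sketch of the same pattern (the toy model and optimizer are assumptions):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for the actor network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()

# clip_grad_norm_ modifies gradients in place; its return value
# (the pre-clip total norm) is only needed for logging.
nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
optimizer.step()
optimizer.zero_grad()
```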
11 changes: 0 additions & 11 deletions zsceval/algorithms/r_mappo/r_mappo_target.py
@@ -1,20 +1,9 @@
import math
import time
from collections import defaultdict
from pprint import pformat

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from loguru import logger

from zsceval.algorithms.r_mappo.algorithm.rMAPPOPolicy import R_MAPPOPolicy
from zsceval.algorithms.r_mappo.r_mappo import R_MAPPO
from zsceval.algorithms.utils.util import check
from zsceval.utils.util import get_gard_norm, huber_loss, mse_loss
from zsceval.utils.valuenorm import ValueNorm


class R_MAPPO_Target(R_MAPPO):
3 changes: 0 additions & 3 deletions zsceval/algorithms/utils/act.py
@@ -1,9 +1,6 @@
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .distributions import Bernoulli, Categorical, DiagGaussian

1 change: 0 additions & 1 deletion zsceval/algorithms/utils/attention.py
@@ -1,6 +1,5 @@
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
1 change: 0 additions & 1 deletion zsceval/algorithms/utils/cnn.py
@@ -1,7 +1,6 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .util import init

3 changes: 0 additions & 3 deletions zsceval/algorithms/utils/cnn_simple.py
@@ -1,7 +1,4 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .util import init

2 changes: 0 additions & 2 deletions zsceval/algorithms/utils/distributions.py
@@ -1,8 +1,6 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger

from .util import init
2 changes: 0 additions & 2 deletions zsceval/algorithms/utils/mix.py
@@ -1,10 +1,8 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

from .resnet import MapNet, Pre_MapNet
from .util import init


3 changes: 0 additions & 3 deletions zsceval/algorithms/utils/mlp.py
@@ -1,8 +1,5 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger

from .attention import Encoder
from .util import get_clones, init
3 changes: 0 additions & 3 deletions zsceval/algorithms/utils/rnn.py
@@ -1,9 +1,6 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .util import get_clones, init


class RNNLayer(nn.Module):
2 changes: 0 additions & 2 deletions zsceval/algorithms/utils/util.py
@@ -1,6 +1,4 @@
import copy
import glob
import os

import numpy as np
import torch
1 change: 0 additions & 1 deletion zsceval/envs/__init__.py
@@ -1,4 +1,3 @@
-import socket

from absl import flags

9 changes: 1 addition & 8 deletions zsceval/envs/env_wrappers.py
@@ -3,16 +3,13 @@
"""

import multiprocessing as mp
import os
from abc import ABC, abstractmethod
from multiprocessing import Pipe, Process
-from typing import Callable, List, Tuple, Union
+from typing import List, Tuple, Union

import cloudpickle
import numpy as np
import psutil
import torch
from loguru import logger

from zsceval.utils.util import tile_images

@@ -61,7 +58,6 @@ def reset(self):
be cancelled and step_wait() should not be called
until step_async() is invoked again.
"""
-pass

@abstractmethod
def step_async(self, actions):
@@ -73,7 +69,6 @@ def step_async(self, actions):
You should not call this if a step_async run is
already pending.
"""
-pass

@abstractmethod
def step_wait(self):
@@ -87,14 +82,12 @@ def step_wait(self):
- dones: an array of "episode done" booleans
- infos: a sequence of info objects
"""
-pass

def close_extras(self):
"""
Clean up the extra resources, beyond what's in this base class.
Only runs when not self.closed.
"""
-pass

def close(self):
if self.closed:
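The docstrings in this hunk describe the two-phase step protocol of the vectorized-env base class, which the removed `pass` statements were merely padding (a docstring alone is a valid method body). A minimal sketch of the calling convention, assuming `env` is a concrete subclass of the wrapper:

```python
import numpy as np

def rollout_step(env, actions: np.ndarray):
    """One vectorized step: submit actions, then block for the results."""
    env.step_async(actions)                    # returns immediately
    obs, rews, dones, infos = env.step_wait()  # blocks until all envs report
    return obs, rews, dones, infos
```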
4 changes: 0 additions & 4 deletions zsceval/envs/grf/grf_env.py
@@ -1,12 +1,8 @@
import random
from os import stat
from pathlib import Path
from typing import Dict, List, Tuple, Union

import gfootball.env as football_env
import numpy as np
from gym import spaces
from loguru import logger

SHAPED_INFOS = [
"pass",
4 changes: 2 additions & 2 deletions zsceval/envs/grf/raw_feature_process.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Tuple

import numpy as np
from gym.spaces import Box
@@ -273,7 +273,7 @@ def get_available_actions(self, obs_dict: Dict) -> np.ndarray:
) = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)

if self.action_n == 20:
-BUILTIN_AI = 19
+pass

# if obs_dict["ball_owned_team"] == 1: # opponents owning ball
# (
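`get_available_actions` enumerates the 19 base GRF action ids and, when `action_n == 20`, implicitly treats id 19 as the built-in-AI action (the now-removed `BUILTIN_AI = 19` binding was unused). A sketch of the mask convention such a method returns — shapes and semantics here are assumptions based on the constants in the hunk:

```python
import numpy as np

def full_action_mask(action_n: int = 19) -> np.ndarray:
    # 1 = action currently available, 0 = unavailable
    return np.ones(action_n, dtype=np.int8)

mask = full_action_mask(20)
mask[19] = 0  # e.g. forbid the built-in-AI action (id 19) for learned agents
print(mask)
```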
5 changes: 1 addition & 4 deletions zsceval/envs/grf/reward_process.py
@@ -1,8 +1,5 @@
-from typing import Dict, List
+from typing import Dict

-import numpy as np
-import torch
-from loguru import logger


class Rewarder:
3 changes: 1 addition & 2 deletions zsceval/envs/grf/stats_process.py
@@ -1,7 +1,6 @@
from typing import Dict, List, Tuple, Union

import numpy as np
-from loguru import logger

from .grf_env import SHAPED_INFOS

@@ -35,7 +34,7 @@ def observe(

next_ball_own_team = next_obs_dict_list[0]["ball_owned_team"]
next_ball_own_player = next_obs_dict_list[0]["ball_owned_player"]
-next_game_mode = next_obs_dict_list[0]["game_mode"]
+next_obs_dict_list[0]["game_mode"]
next_my_score, next_opp_score = next_obs_dict_list[0]["score"]

if ball_own_team != BALL_NO_OWNER:
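`observe` reads per-step fields from the raw GRF observation dicts (ball ownership, game mode, score) to accumulate behavior stats. A sketch of the fields this hunk touches — the dict layout follows the gfootball raw-observation format, and the concrete values are illustrative:

```python
# Illustrative raw-observation fields used by the hunk above.
obs_dict = {
    "ball_owned_team": 0,    # -1 = no owner, 0 = left team, 1 = right team
    "ball_owned_player": 4,  # index of the owning player, -1 if none
    "game_mode": 0,          # 0 = normal play; other values are set pieces
    "score": (1, 0),         # (my_score, opp_score)
}

my_score, opp_score = obs_dict["score"]
if obs_dict["ball_owned_team"] != -1:
    print("ball owned by team", obs_dict["ball_owned_team"])
```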