forked from KevinHuang8/DATT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
122 lines (87 loc) · 3.92 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import sys
import numpy as np
from DATT.quadsim.sim import QuadSim
from DATT.quadsim.models import IdentityModel
from DATT.refs.pointed_star import NPointedStar
from DATT.learning.configs import *
from DATT.controllers import cntrl_config_presets, ControllersZoo
from DATT.configuration.configuration import AllConfig
from DATT.refs import TrajectoryRef
from DATT.python_utils.plotu import subplot, set_3daxes_equal
from DATT.controllers.hybrid_controller import HybridController
import matplotlib.pyplot as plt
from pathlib import Path
if __name__ == "__main__":
    # Command-line entry point: load an environment config, build a reference
    # trajectory and a controller, run one QuadSim simulation, and optionally
    # plot the resulting trajectory against the reference.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--cntrl_config', default='datt_hover_config', type=str,
                        help='Pick or Make a config preset from DATT/quadsim/controllers/cntrl_config_presets')
    parser.add_argument('--cntrl', default=ControllersZoo.DATT, type=ControllersZoo)
    parser.add_argument('--env_config', default='datt_hover.py',
                        help='Environment config passed to import_config.')
    parser.add_argument('-r', '--ref', dest='ref', type=TrajectoryRef, default=TrajectoryRef.LINE_REF)
    parser.add_argument('--seed', type=int, default=0)
    # The three options below were previously hard-coded locals; their
    # defaults reproduce the old behavior (t_end=10.0, vis=True, plot=True).
    parser.add_argument('--t_end', type=float, default=10.0,
                        help='End time passed to QuadSim.simulate.')
    parser.add_argument('--no_vis', dest='vis', action='store_false',
                        help='Construct QuadSim with vis=False.')
    parser.add_argument('--no_plot', dest='plot', action='store_false',
                        help='Exit after simulating instead of plotting.')
    args = parser.parse_args()

    def _load_preset(name):
        """Return the attribute *name* of ``cntrl_config_presets``.

        Exits through ``parser.error`` when no such preset exists. Previously a
        placeholder string was passed to the controller as if it were a config.
        """
        try:
            return getattr(cntrl_config_presets, name)
        except AttributeError:
            parser.error(f"controller config preset {name!r} not found in cntrl_config_presets")

    config: AllConfig = import_config(args.env_config)
    dt = config.sim_config.dt()

    # Loading refs
    ref = args.ref.ref(config.ref_config,
                       seed=args.seed,
                       env_diff_seed=config.training_config.env_diff_seed)

    # Loading drone configs
    model = IdentityModel()

    # Loading sim
    quadsim = QuadSim(model, vis=args.vis)

    # Loading controller
    cntrl: ControllersZoo = args.cntrl
    if cntrl == ControllersZoo.HYBRID:
        # The hybrid controller ignores --cntrl_config and always combines
        # these two named presets.
        mppi_config = _load_preset("MPPIConfig")
        datt_config = _load_preset("DATTConfig")
        controller = HybridController(config, mppi_config, datt_config)
    else:
        cntrl_config = _load_preset(args.cntrl_config)
        controller = cntrl.cntrl(config, {cntrl.value: cntrl_config})
    controller.ref_func = ref

    # Optional disturbances; uncomment an entry to enable it.
    dists = [
        # ConstantForce(np.array([4, 4, 4]))
        # WindField(pos=np.array((-1, 1.5, 0.0)), direction=np.array((1, 0, 0)), noisevar=25.0, vmax=1500.0, decay_long=1.8)
    ]
    ts = quadsim.simulate(dt=dt, t_end=args.t_end, controller=controller, dists=dists)

    if not args.plot:
        sys.exit(0)

    # as_euler('ZYX') is reversed, so each row is ordered [X, Y, Z] angles.
    eulers = np.array([rot.as_euler('ZYX')[::-1] for rot in ts.rot])
    # Evaluate the reference once; the position plots index it as
    # ref_pos[axis, :].
    ref_pos = ref.pos(ts.times)

    # Per-axis position (actual vs. reference) on a shared time axis.
    fig, axes = plt.subplots(3, 1, sharex=True)
    for i, axis in enumerate(axes):
        axis.plot(ts.times, ts.pos[:, i])
        axis.plot(ts.times, ref_pos[i, :])
    fig.suptitle(type(controller).__name__)

    # Top-down X-Y path.
    plt.figure()
    plt.plot(ts.pos[:, 0], ts.pos[:, 1], label='actual')
    # plt.plot(ref_pos[0, :], ref_pos[1, :], label='desired')
    plt.legend()

    # subplot(ts.times, ts.pos, yname="Pos. (m)", title="Position", des=ref_pos)
    subplot(ts.times, ts.vel, yname="Vel. (m/s)", title="Velocity")
    subplot(ts.times, ref.vel(ts.times).T, yname="Vel. (m/s)", title="Velocity", label="Desired")
    subplot(ts.times, eulers, yname="Euler (rad)", title="ZYX Euler Angles")
    subplot(ts.times, ts.ang, yname="$\\omega$ (rad/s)", title="Angular Velocity")
    subplot(ts.times, ts.force, yname="Force (N)", title="Body Z Thrust")

    # Optional 3D trajectory view:
    # fig = plt.figure(num="Trajectory")
    # ax = fig.add_subplot(111, projection='3d')
    # plt.plot(ts.pos[:, 0], ts.pos[:, 1], ts.pos[:, 2])
    # plt.xlabel("X (m)")
    # plt.ylabel("Y (m)")
    # ax.set_zlabel("Z (m)")
    # plt.title("Trajectory")
    # set_3daxes_equal(ax)
    plt.show()