ChoreoLib
Choreo support library.
Loading...
Searching...
No Matches
Trajectory.h
1// Copyright (c) Choreo contributors
2
3#pragma once
4
5#include <algorithm>
6#include <optional>
7#include <ranges>
8#include <string>
9#include <string_view>
10#include <utility>
11#include <vector>
12
13#include <units/time.h>
14#include <wpi/json_fwd.h>
15
16#include "choreo/trajectory/DifferentialSample.h"
17#include "choreo/trajectory/EventMarker.h"
18#include "choreo/trajectory/SwerveSample.h"
19#include "choreo/trajectory/TrajectorySample.h"
20
21namespace choreo {
22
26template <TrajectorySample SampleType>
28 public:
30 Trajectory() = default;
31
38 Trajectory(std::string_view name, std::vector<SampleType> samples,
39 std::vector<int> splits, std::vector<EventMarker> events)
40 : name{name},
41 samples{std::move(samples)},
42 splits{std::move(splits)},
43 events{std::move(events)} {}
44
52 std::optional<SampleType> GetInitialSample(
53 bool mirrorForRedAlliance = false) const {
54 if (samples.size() == 0) {
55 return {};
56 }
57 return mirrorForRedAlliance ? samples.front().Flipped() : samples.front();
58 }
59
67 std::optional<SampleType> GetFinalSample(
68 bool mirrorForRedAlliance = false) const {
69 if (samples.size() == 0) {
70 return {};
71 }
72 return mirrorForRedAlliance ? samples.back().Flipped() : samples.back();
73 }
74
84 template <int Year = util::kDefaultYear>
85 std::optional<SampleType> SampleAt(units::second_t timestamp,
86 bool mirrorForRedAlliance = false) const {
87 if (auto state = SampleInternal(timestamp)) {
88 return mirrorForRedAlliance ? state.value().template Flipped<Year>()
89 : state;
90 } else {
91 return {};
92 }
93 }
94
102 template <int Year = util::kDefaultYear>
103 std::optional<frc::Pose2d> GetInitialPose(
104 bool mirrorForRedAlliance = false) const {
105 if (samples.size() == 0) {
106 return {};
107 }
108 if (mirrorForRedAlliance) {
109 return samples.front().template Flipped<Year>().GetPose();
110 } else {
111 return samples.front().GetPose();
112 }
113 }
114
122 template <int Year = util::kDefaultYear>
123 std::optional<frc::Pose2d> GetFinalPose(
124 bool mirrorForRedAlliance = false) const {
125 if (samples.size() == 0) {
126 return {};
127 }
128 if (mirrorForRedAlliance) {
129 return samples.back().template Flipped<Year>().GetPose();
130 } else {
131 return samples.back().GetPose();
132 }
133 }
134
139 units::second_t GetTotalTime() const {
140 if (samples.size() == 0) {
141 return 0_s;
142 }
143 return GetFinalSample().value().GetTimestamp();
144 }
145
149 std::vector<frc::Pose2d> GetPoses() const {
150 std::vector<frc::Pose2d> poses;
151 for (const auto& sample : samples) {
152 poses.push_back(sample.GetPose());
153 }
154 return poses;
155 }
156
161 template <int Year = util::kDefaultYear>
163 std::vector<SampleType> flippedStates;
164 for (const auto& state : samples) {
165 flippedStates.push_back(state.template Flipped<Year>());
166 }
167 return Trajectory<SampleType>(name, flippedStates, splits, events);
168 }
169
175 std::vector<EventMarker> GetEvents(std::string_view eventName) const {
176 std::vector<EventMarker> matchingEvents;
177 for (const auto& event : events) {
178 if (event.event == eventName) {
179 matchingEvents.push_back(event);
180 }
181 }
182 return matchingEvents;
183 }
184
191 std::optional<Trajectory<SampleType>> GetSplit(int splitIndex) const {
192 // Assumption: splits[splitIndex] is a valid index of samples.
193 if (splitIndex < 0 || splitIndex >= splits.size()) {
194 return std::nullopt;
195 }
196
197 int start = splits[splitIndex];
198 int end = (splitIndex + 1 < splits.size()) ? splits[splitIndex + 1] + 1
199 : samples.size();
200
201 auto sublist =
202 std::vector<SampleType>(samples.begin() + start, samples.begin() + end);
203 // Empty section should not be achievable (would mean malformed splits
204 // array), but is handled for safety
205 if (sublist.size() == 0) {
207 name + "[" + std::to_string(splitIndex) + "]", {}, {}, {}};
208 }
209 // Now we know sublist.size() >= 1
210 units::second_t startTime = sublist.front().GetTimestamp();
211 units::second_t endTime = sublist.back().GetTimestamp();
212
213 auto offsetSamples =
214 sublist | std::views::transform([startTime](const SampleType& s) {
215 return s.OffsetBy(-startTime);
216 });
217
218 auto filteredEvents =
219 events | std::views::filter([startTime, endTime](const auto& e) {
220 return e.timestamp >= startTime && e.timestamp <= endTime;
221 }) |
222 std::views::transform(
223 [startTime](const auto& e) { return e.OffsetBy(-startTime); });
224
226 name + "[" + std::to_string(splitIndex) + "]",
227 std::vector<SampleType>(offsetSamples.begin(), offsetSamples.end()),
228 {},
229 std::vector<EventMarker>(filteredEvents.begin(), filteredEvents.end())};
230 }
231
236 bool operator==(const Trajectory<SampleType>& other) const {
237 if (name != other.name) {
238 return false;
239 }
240
241 if (samples.size() != other.samples.size()) {
242 return false;
243 }
244 if (!std::equal(samples.begin(), samples.end(), other.samples.begin())) {
245 return false;
246 }
247
248 if (splits != other.splits) {
249 return false;
250 }
251
252 if (events.size() != other.events.size()) {
253 return false;
254 }
255 if (!std::equal(events.begin(), events.end(), other.events.begin())) {
256 return false;
257 }
258
259 return true;
260 }
261
263 std::string name;
264
266 std::vector<SampleType> samples;
267
269 std::vector<int> splits;
270
272 std::vector<EventMarker> events;
273
274 private:
275 std::optional<SampleType> SampleInternal(units::second_t timestamp) const {
276 if (samples.size() == 0) {
277 return {};
278 }
279 if (samples.size() == 1) {
280 return samples[0];
281 }
282 if (timestamp < samples[0].GetTimestamp()) {
283 return GetInitialSample();
284 }
285 if (timestamp >= GetTotalTime()) {
286 return GetFinalSample();
287 }
288
289 int low = 0;
290 int high = samples.size() - 1;
291
292 while (low != high) {
293 int mid = (low + high) / 2;
294 if (samples[mid].GetTimestamp() < timestamp) {
295 low = mid + 1;
296 } else {
297 high = mid;
298 }
299 }
300
301 if (low == 0) {
302 return samples[low];
303 }
304
305 SampleType behindState = samples[low - 1];
306 SampleType aheadState = samples[low];
307
308 if ((aheadState.GetTimestamp() - behindState.GetTimestamp()) < 1e-6_s) {
309 return aheadState;
310 }
311
312 return behindState.Interpolate(aheadState, timestamp);
313 }
314};
315
316void to_json(wpi::json& json, const Trajectory<SwerveSample>& trajectory);
317void from_json(const wpi::json& json, Trajectory<SwerveSample>& trajectory);
318
319void to_json(wpi::json& json, const Trajectory<DifferentialSample>& trajectory);
320void from_json(const wpi::json& json,
321 Trajectory<DifferentialSample>& trajectory);
322
323} // namespace choreo
Definition Trajectory.h:27
std::vector< frc::Pose2d > GetPoses() const
Definition Trajectory.h:149
std::vector< SampleType > samples
The vector of samples in the trajectory.
Definition Trajectory.h:266
std::optional< frc::Pose2d > GetInitialPose(bool mirrorForRedAlliance=false) const
Definition Trajectory.h:103
std::string name
The name of the trajectory.
Definition Trajectory.h:263
std::vector< int > splits
The waypoints indexes where the trajectory is split.
Definition Trajectory.h:269
Trajectory(std::string_view name, std::vector< SampleType > samples, std::vector< int > splits, std::vector< EventMarker > events)
Definition Trajectory.h:38
std::optional< frc::Pose2d > GetFinalPose(bool mirrorForRedAlliance=false) const
Definition Trajectory.h:123
std::optional< SampleType > SampleAt(units::second_t timestamp, bool mirrorForRedAlliance=false) const
Definition Trajectory.h:85
Trajectory< SampleType > Flipped() const
Definition Trajectory.h:162
std::optional< SampleType > GetInitialSample(bool mirrorForRedAlliance=false) const
Definition Trajectory.h:52
Trajectory()=default
Constructs a Trajectory with defaults.
units::second_t GetTotalTime() const
Definition Trajectory.h:139
std::optional< SampleType > GetFinalSample(bool mirrorForRedAlliance=false) const
Definition Trajectory.h:67
bool operator==(const Trajectory< SampleType > &other) const
Definition Trajectory.h:236
std::optional< Trajectory< SampleType > > GetSplit(int splitIndex) const
Definition Trajectory.h:191
std::vector< EventMarker > GetEvents(std::string_view eventName) const
Definition Trajectory.h:175
std::vector< EventMarker > events
A vector of all of the events in the trajectory.
Definition Trajectory.h:272