ChoreoLib
Choreo support library.
Loading...
Searching...
No Matches
Trajectory.h
1// Copyright (c) Choreo contributors
2
3#pragma once
4
5#include <algorithm>
6#include <optional>
7#include <ranges>
8#include <string>
9#include <string_view>
10#include <utility>
11#include <vector>
12
13#include <units/time.h>
14#include <wpi/json_fwd.h>
15
16#include "choreo/trajectory/DifferentialSample.h"
17#include "choreo/trajectory/EventMarker.h"
18#include "choreo/trajectory/SwerveSample.h"
19#include "choreo/trajectory/TrajectorySample.h"
20
21namespace choreo {
22
28template <TrajectorySample SampleType>
30 public:
34 Trajectory() = default;
35
44 Trajectory(std::string_view name, std::vector<SampleType> samples,
45 std::vector<int> splits, std::vector<EventMarker> events)
46 : name{name},
47 samples{std::move(samples)},
48 splits{std::move(splits)},
49 events{std::move(events)} {}
50
60 std::optional<SampleType> GetInitialSample(
61 bool mirrorForRedAlliance = false) const {
62 if (samples.size() == 0) {
63 return {};
64 }
65 return mirrorForRedAlliance ? samples.front().Flipped() : samples.front();
66 }
67
77 std::optional<SampleType> GetFinalSample(
78 bool mirrorForRedAlliance = false) const {
79 if (samples.size() == 0) {
80 return {};
81 }
82 return mirrorForRedAlliance ? samples.back().Flipped() : samples.back();
83 }
84
96 template <int Year = util::kDefaultYear>
97 std::optional<SampleType> SampleAt(units::second_t timestamp,
98 bool mirrorForRedAlliance = false) const {
99 if (auto state = SampleInternal(timestamp)) {
100 return mirrorForRedAlliance ? state.value().template Flipped<Year>()
101 : state;
102 } else {
103 return {};
104 }
105 }
106
116 template <int Year = util::kDefaultYear>
117 std::optional<frc::Pose2d> GetInitialPose(
118 bool mirrorForRedAlliance = false) const {
119 if (samples.size() == 0) {
120 return {};
121 }
122 if (mirrorForRedAlliance) {
123 return samples.front().template Flipped<Year>().GetPose();
124 } else {
125 return samples.front().GetPose();
126 }
127 }
128
138 template <int Year = util::kDefaultYear>
139 std::optional<frc::Pose2d> GetFinalPose(
140 bool mirrorForRedAlliance = false) const {
141 if (samples.size() == 0) {
142 return {};
143 }
144 if (mirrorForRedAlliance) {
145 return samples.back().template Flipped<Year>().GetPose();
146 } else {
147 return samples.back().GetPose();
148 }
149 }
150
157 units::second_t GetTotalTime() const {
158 if (samples.size() == 0) {
159 return 0_s;
160 }
161 return GetFinalSample().value().GetTimestamp();
162 }
163
169 std::vector<frc::Pose2d> GetPoses() const {
170 std::vector<frc::Pose2d> poses;
171 for (const auto& sample : samples) {
172 poses.push_back(sample.GetPose());
173 }
174 return poses;
175 }
176
183 template <int Year = util::kDefaultYear>
185 std::vector<SampleType> flippedStates;
186 for (const auto& state : samples) {
187 flippedStates.push_back(state.template Flipped<Year>());
188 }
189 return Trajectory<SampleType>(name, flippedStates, splits, events);
190 }
191
199 std::vector<EventMarker> GetEvents(std::string_view eventName) const {
200 std::vector<EventMarker> matchingEvents;
201 for (const auto& event : events) {
202 if (event.event == eventName) {
203 matchingEvents.push_back(event);
204 }
205 }
206 return matchingEvents;
207 }
208
217 std::optional<Trajectory<SampleType>> GetSplit(int splitIndex) const {
218 // Assumption: splits[splitIndex] is a valid index of samples.
219 if (splitIndex < 0 || splitIndex >= splits.size()) {
220 return std::nullopt;
221 }
222
223 int start = splits[splitIndex];
224 int end = (splitIndex + 1 < splits.size()) ? splits[splitIndex + 1] + 1
225 : samples.size();
226
227 auto sublist =
228 std::vector<SampleType>(samples.begin() + start, samples.begin() + end);
229 // Empty section should not be achievable (would mean malformed splits
230 // array), but is handled for safety
231 if (sublist.size() == 0) {
233 name + "[" + std::to_string(splitIndex) + "]", {}, {}, {}};
234 }
235 // Now we know sublist.size() >= 1
236 units::second_t startTime = sublist.front().GetTimestamp();
237 units::second_t endTime = sublist.back().GetTimestamp();
238
239 auto offsetSamples =
240 sublist | std::views::transform([startTime](const SampleType& s) {
241 return s.OffsetBy(-startTime);
242 });
243
244 auto filteredEvents =
245 events | std::views::filter([startTime, endTime](const auto& e) {
246 return e.timestamp >= startTime && e.timestamp <= endTime;
247 }) |
248 std::views::transform(
249 [startTime](const auto& e) { return e.OffsetBy(-startTime); });
250
252 name + "[" + std::to_string(splitIndex) + "]",
253 std::vector<SampleType>(offsetSamples.begin(), offsetSamples.end()),
254 {},
255 std::vector<EventMarker>(filteredEvents.begin(), filteredEvents.end())};
256 }
257
264 bool operator==(const Trajectory<SampleType>& other) const {
265 if (name != other.name) {
266 return false;
267 }
268
269 if (samples.size() != other.samples.size()) {
270 return false;
271 }
272 if (!std::equal(samples.begin(), samples.end(), other.samples.begin())) {
273 return false;
274 }
275
276 if (splits != other.splits) {
277 return false;
278 }
279
280 if (events.size() != other.events.size()) {
281 return false;
282 }
283 if (!std::equal(events.begin(), events.end(), other.events.begin())) {
284 return false;
285 }
286
287 return true;
288 }
289
291 std::string name;
292
294 std::vector<SampleType> samples;
295
297 std::vector<int> splits;
298
300 std::vector<EventMarker> events;
301
302 private:
303 std::optional<SampleType> SampleInternal(units::second_t timestamp) const {
304 if (samples.size() == 0) {
305 return {};
306 }
307 if (samples.size() == 1) {
308 return samples[0];
309 }
310 if (timestamp < samples[0].GetTimestamp()) {
311 return GetInitialSample();
312 }
313 if (timestamp >= GetTotalTime()) {
314 return GetFinalSample();
315 }
316
317 int low = 0;
318 int high = samples.size() - 1;
319
320 while (low != high) {
321 int mid = (low + high) / 2;
322 if (samples[mid].GetTimestamp() < timestamp) {
323 low = mid + 1;
324 } else {
325 high = mid;
326 }
327 }
328
329 if (low == 0) {
330 return samples[low];
331 }
332
333 SampleType behindState = samples[low - 1];
334 SampleType aheadState = samples[low];
335
336 if ((aheadState.GetTimestamp() - behindState.GetTimestamp()) < 1e-6_s) {
337 return aheadState;
338 }
339
340 return behindState.Interpolate(aheadState, timestamp);
341 }
342};
343
// JSON conversion for Trajectory<SwerveSample>. Only declared in this header;
// the definitions live in a separate source file not shown here.
// NOTE(review): presumably picked up by wpi::json through argument-dependent
// lookup (nlohmann-style to_json/from_json convention) — confirm.
344void to_json(wpi::json& json, const Trajectory<SwerveSample>& trajectory);
345void from_json(const wpi::json& json, Trajectory<SwerveSample>& trajectory);
346
// JSON conversion for Trajectory<DifferentialSample>; declared here, defined
// elsewhere, following the same pattern as the SwerveSample overloads.
347void to_json(wpi::json& json, const Trajectory<DifferentialSample>& trajectory);
348void from_json(const wpi::json& json,
349 Trajectory<DifferentialSample>& trajectory);
350
351} // namespace choreo
class choreo::Trajectory< SampleType >
Definition Trajectory.h:29
std::vector< frc::Pose2d > GetPoses() const
Definition Trajectory.h:169
std::vector< SampleType > samples
The vector of samples in the trajectory.
Definition Trajectory.h:294
std::optional< frc::Pose2d > GetInitialPose(bool mirrorForRedAlliance=false) const
Definition Trajectory.h:117
std::string name
The name of the trajectory.
Definition Trajectory.h:291
std::vector< int > splits
The waypoints indexes where the trajectory is split.
Definition Trajectory.h:297
Trajectory(std::string_view name, std::vector< SampleType > samples, std::vector< int > splits, std::vector< EventMarker > events)
Definition Trajectory.h:44
std::optional< frc::Pose2d > GetFinalPose(bool mirrorForRedAlliance=false) const
Definition Trajectory.h:139
std::optional< SampleType > SampleAt(units::second_t timestamp, bool mirrorForRedAlliance=false) const
Definition Trajectory.h:97
Trajectory< SampleType > Flipped() const
Definition Trajectory.h:184
std::optional< SampleType > GetInitialSample(bool mirrorForRedAlliance=false) const
Definition Trajectory.h:60
Trajectory()=default
units::second_t GetTotalTime() const
Definition Trajectory.h:157
std::optional< SampleType > GetFinalSample(bool mirrorForRedAlliance=false) const
Definition Trajectory.h:77
bool operator==(const Trajectory< SampleType > &other) const
Definition Trajectory.h:264
std::optional< Trajectory< SampleType > > GetSplit(int splitIndex) const
Definition Trajectory.h:217
std::vector< EventMarker > GetEvents(std::string_view eventName) const
Definition Trajectory.h:199
std::vector< EventMarker > events
A vector of all of the events in the trajectory.
Definition Trajectory.h:300