update(_:with:)
Updates a model with a new batch of examples.
Declaration
func update(_ model: inout LinearTimeSeriesForecaster<Scalar>.Transformer, with input: AnnotatedBatch<Scalar>) async throws -> ScalarParameters
Parameters
- model:
The model to update.
- input:
A batch of windowed features and their annotations. The features should have the shape
[batchSize, inputWindowSize, featureSize].
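The expected input shape can be illustrated with a small sketch. The makeWindowedBatches helper below is hypothetical. It is not part of Create ML Components and is not how TimeSeriesForecasterBatches is implemented. It assumes windows advance one row at a time, ignores annotations and the forecast window, and uses nested Swift arrays in place of shaped arrays. Each full batch it produces has the shape [batchSize, inputWindowSize, featureSize].

```swift
/// Hypothetical helper (not a Create ML Components API): slides a window of
/// `inputWindowSize` rows over a [N, featureSize] series with stride 1
/// (an assumption), then groups consecutive windows into batches of at most
/// `batchSize`. Each full batch has shape [batchSize, inputWindowSize, featureSize].
func makeWindowedBatches(
    features: [[Float]],
    inputWindowSize: Int,
    batchSize: Int
) -> [[[[Float]]]] {
    precondition(inputWindowSize > 0 && batchSize > 0)
    guard features.count >= inputWindowSize else { return [] }

    // One window per valid starting row: N - inputWindowSize + 1 windows.
    let windows: [[[Float]]] = (0...(features.count - inputWindowSize)).map { start in
        Array(features[start..<(start + inputWindowSize)])
    }

    // Chunk the windows into batches; the last batch may be smaller.
    return stride(from: 0, to: windows.count, by: batchSize).map { start in
        Array(windows[start..<min(start + batchSize, windows.count)])
    }
}

// Usage: 10 time steps with 2 features each.
let series: [[Float]] = (0..<10).map { t in [Float(t), Float(t) * 0.5] }
let batches = makeWindowedBatches(features: series, inputWindowSize: 3, batchSize: 4)
print(batches.count)  // 8 windows grouped into 2 batches
print(batches[0].count, batches[0][0].count, batches[0][0][0].count)  // 4 3 2
```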
Discussion
Use TimeSeriesForecasterBatches to convert shaped arrays of features and annotations into batches of windowed features and annotations. Here is an example of training a forecaster:
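```swift
// Assumes `configuration`, `features`, and `annotations` are defined elsewhere,
// and that this code runs inside an async throwing context.
let estimator = LinearTimeSeriesForecaster<Float>(configuration: configuration)
var model = estimator.makeTransformer()

// Slice the full series into shuffled batches of windowed features and annotations.
let batches = try TimeSeriesForecasterBatches(
    features: features,        // shape [N, featureSize]
    annotations: annotations,  // shape [N, annotationSize]
    batchSize: 32,
    inputWindowSize: configuration.inputWindowSize,
    forecastWindowSize: configuration.forecastWindowSize,
    shufflesBatches: true
)

// Each outer iteration is one pass over all batches; each update call is one step.
for _ in 0 ..< configuration.maximumIterationCount {
    for batch in batches {
        let loss = try await estimator.update(&model, with: batch)
        print("Loss: \(loss)")
    }
}
```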
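To make "updates a model with a new batch of examples" concrete, here is a minimal sketch of a single update step for a linear forecaster. It is not the Create ML Components implementation, whose loss and optimizer this page does not specify. The LinearModel type, the update(_:with:learningRate:) function, and the choice of plain gradient descent on mean squared error are all assumptions for illustration. Each input stands for a flattened feature window and each target for a flattened forecast window.

```swift
/// Hypothetical stand-in for a linear forecaster's parameters (not the
/// Create ML Components type): maps a flattened input window of length
/// `inputCount` to a flattened forecast of length `outputCount`.
struct LinearModel {
    var weights: [[Float]]  // [outputCount][inputCount]
    var bias: [Float]       // [outputCount]

    init(inputCount: Int, outputCount: Int) {
        weights = Array(repeating: Array(repeating: 0, count: inputCount), count: outputCount)
        bias = Array(repeating: 0, count: outputCount)
    }

    func predict(_ x: [Float]) -> [Float] {
        (0..<bias.count).map { o in
            zip(weights[o], x).reduce(bias[o]) { $0 + $1.0 * $1.1 }
        }
    }
}

/// Hypothetical update step: one plain gradient-descent step on the mean
/// squared error over a batch of (input, target) pairs. Returns the batch
/// loss measured before the parameters change.
func update(_ model: inout LinearModel,
            with batch: [(input: [Float], target: [Float])],
            learningRate: Float = 0.01) -> Float {
    guard !batch.isEmpty else { return 0 }
    var gradW = model.weights.map { $0.map { _ in Float(0) } }
    var gradB = model.bias.map { _ in Float(0) }
    var loss: Float = 0
    let scale = 1 / Float(batch.count * model.bias.count)

    for (x, y) in batch {
        let prediction = model.predict(x)
        for o in 0..<prediction.count {
            let error = prediction[o] - y[o]
            loss += error * error * scale
            let g = 2 * error * scale  // d(loss)/d(prediction[o])
            gradB[o] += g
            for i in 0..<x.count { gradW[o][i] += g * x[i] }
        }
    }
    // Move every parameter against its accumulated gradient.
    for o in 0..<model.bias.count {
        model.bias[o] -= learningRate * gradB[o]
        for i in 0..<model.weights[o].count {
            model.weights[o][i] -= learningRate * gradW[o][i]
        }
    }
    return loss
}

// Usage: learn y = 2x + 1 from a tiny synthetic batch.
var model = LinearModel(inputCount: 1, outputCount: 1)
let batch: [(input: [Float], target: [Float])] = [
    (input: [0], target: [1]), (input: [1], target: [3]), (input: [2], target: [5]),
]
var firstLoss: Float = 0
var lastLoss: Float = 0
for step in 0..<500 {
    let loss = update(&model, with: batch, learningRate: 0.1)
    if step == 0 { firstLoss = loss }
    lastLoss = loss
}
print(firstLoss, lastLoss, model.weights[0][0], model.bias[0])
```

Reporting the loss measured before the step is one reasonable choice for per-batch logging; an actual implementation might report it after the step or use an adaptive optimizer instead of plain gradient descent.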