Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include SA updates made in-workflow when continuing as new #797

Merged
merged 3 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions core/src/worker/workflow/driven_workflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,21 @@ use crate::{
worker::workflow::{OutgoingJob, WFCommand, WorkflowStartedInfo},
};
use prost_types::Timestamp;
use std::sync::mpsc::{self, Receiver, Sender};
use std::{
collections::HashMap,
sync::mpsc::{self, Receiver, Sender},
};
use temporal_sdk_core_protos::{
coresdk::workflow_activation::{start_workflow_from_attribs, WorkflowActivationJob},
temporal::api::history::v1::WorkflowExecutionStartedEventAttributes,
temporal::api::{common::v1::Payload, history::v1::WorkflowExecutionStartedEventAttributes},
utilities::TryIntoOrNone,
};

/// Represents a connection to a lang side workflow that can have activations fed into it and
/// command responses pulled out.
pub(crate) struct DrivenWorkflow {
started_attrs: Option<WorkflowStartedInfo>,
search_attribute_modifications: HashMap<String, Payload>,
incoming_commands: Receiver<Vec<WFCommand>>,
/// Outgoing activation jobs that need to be sent to the lang sdk
outgoing_wf_activation_jobs: Vec<OutgoingJob>,
Expand All @@ -25,6 +29,7 @@ impl DrivenWorkflow {
(
Self {
started_attrs: None,
search_attribute_modifications: Default::default(),
incoming_commands: rx,
outgoing_wf_activation_jobs: vec![],
},
Expand Down Expand Up @@ -85,4 +90,25 @@ impl DrivenWorkflow {
debug!(in_cmds = %in_cmds.display(), "wf bridge iteration fetch");
in_cmds
}

/// Lang sent us an SA upsert command - use it to update our current view of search attributes.
pub(crate) fn search_attributes_update(&mut self, update: HashMap<String, Payload>) {
self.search_attribute_modifications.extend(update);
}

/// Return a view of the "current" state of search attributes. IE: The initial attributes
/// plus any changes during the lifetime of the workflow.
pub(crate) fn get_current_search_attributes(&self) -> HashMap<String, Payload> {
let mut retme = self
.started_attrs
.as_ref()
.and_then(|si| si.search_attrs.as_ref().map(|sa| sa.indexed_fields.clone()))
.unwrap_or_default();
retme.extend(
self.search_attribute_modifications
.iter()
.map(|(a, b)| (a.clone(), b.clone())),
);
retme
}
}
15 changes: 8 additions & 7 deletions core/src/worker/workflow/machines/workflow_machines.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1152,6 +1152,9 @@ impl WorkflowMachines {
);
}
ProtoCmdAttrs::UpsertWorkflowSearchAttributesCommandAttributes(attrs) => {
// We explicitly do not update the workflows current SAs here since
// core-generated upserts aren't meant to be modified or used within
// workflows by users (but rather, just for them to search with).
self.add_cmd_to_wf_task(
upsert_search_attrs_internal(attrs),
CommandIdKind::NeverResolves,
Expand Down Expand Up @@ -1262,6 +1265,8 @@ impl WorkflowMachines {
self.add_cmd_to_wf_task(new_timer(attrs), CommandID::Timer(seq).into());
}
WFCommand::UpsertSearchAttributes(attrs) => {
self.drive_me
.search_attributes_update(attrs.search_attributes.clone());
self.add_cmd_to_wf_task(
upsert_search_attrs(
attrs,
Expand Down Expand Up @@ -1538,17 +1543,13 @@ impl WorkflowMachines {
.map(Into::into)
.unwrap_or_default();
}
if attrs.search_attributes.is_empty() {
attrs.search_attributes = started_info
.search_attrs
.clone()
.map(Into::into)
.unwrap_or_default();
}
if attrs.retry_policy.is_none() {
attrs.retry_policy.clone_from(&started_info.retry_policy);
}
}
if attrs.search_attributes.is_empty() {
attrs.search_attributes = self.drive_me.get_current_search_attributes();
}
attrs
}

Expand Down
11 changes: 9 additions & 2 deletions sdk/src/workflow_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ use crate::{
};
use crossbeam_channel::{Receiver, Sender};
use futures::{task::Context, FutureExt, Stream, StreamExt};
use parking_lot::RwLock;
use parking_lot::{RwLock, RwLockReadGuard};
use std::{
collections::HashMap,
future::Future,
marker::PhantomData,
ops::Deref,
pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Expand All @@ -40,7 +41,7 @@ use temporal_sdk_core_protos::{
SignalExternalWorkflowExecution, StartTimer, UpsertWorkflowSearchAttributes,
},
},
temporal::api::common::v1::{Memo, Payload},
temporal::api::common::v1::{Memo, Payload, SearchAttributes},
};
use tokio::sync::{mpsc, oneshot, watch};
use tokio_stream::wrappers::UnboundedReceiverStream;
Expand Down Expand Up @@ -123,6 +124,11 @@ impl WfContext {
self.shared.read().current_build_id.clone()
}

/// Return current values for workflow search attributes
pub fn search_attributes(&self) -> impl Deref<Target = SearchAttributes> + '_ {
RwLockReadGuard::map(self.shared.read(), |s| &s.search_attributes)
}

/// A future that resolves if/when the workflow is cancelled
pub async fn cancelled(&self) {
if *self.am_cancelled.borrow() {
Expand Down Expand Up @@ -412,6 +418,7 @@ pub(crate) struct WfContextSharedData {
pub(crate) wf_time: Option<SystemTime>,
pub(crate) history_length: u32,
pub(crate) current_build_id: Option<String>,
pub(crate) search_attributes: SearchAttributes,
}

/// Helper Wrapper that can drain the channel into a Vec<SignalData> in a blocking way. Useful
Expand Down
17 changes: 15 additions & 2 deletions sdk/src/workflow_future.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ impl WorkflowFuture {
if let Some(v) = variant {
match v {
Variant::StartWorkflow(_) => {
// TODO: Can assign randomness seed whenever needed
// Don't do anything in here. Start workflow is looked at earlier, before
// jobs are handled, and may have information taken out of it to avoid clones.
}
Variant::FireTimer(FireTimer { seq }) => {
self.unblock(UnblockEvent::Timer(seq, TimerResult::Fired))?
Expand Down Expand Up @@ -311,7 +312,7 @@ impl Future for WorkflowFuture {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
'activations: loop {
// WF must always receive an activation first before responding with commands
let activation = match self.incoming_activations.poll_recv(cx) {
let mut activation = match self.incoming_activations.poll_recv(cx) {
Poll::Ready(a) => match a {
Some(act) => act,
None => {
Expand Down Expand Up @@ -339,6 +340,18 @@ impl Future for WorkflowFuture {

let mut die_of_eviction_when_done = false;
let mut activation_cmds = vec![];
// Assign initial state from start workflow job
if let Some(start_info) = activation.jobs.iter_mut().find_map(|j| {
if let Some(Variant::StartWorkflow(s)) = j.variant.as_mut() {
Some(s)
} else {
None
}
}) {
// TODO: Can assign randomness seed whenever needed
self.wf_ctx.shared.write().search_attributes =
start_info.search_attributes.take().unwrap_or_default();
};
// Lame hack to avoid hitting "unregistered" update handlers in a situation where
// the history has no commands until an update is accepted. Will go away w/ SDK redesign
if activation
Expand Down
47 changes: 35 additions & 12 deletions tests/integ_tests/workflow_tests/upsert_search_attrs.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use std::{collections::HashMap, env};
use temporal_client::{WorkflowClientTrait, WorkflowOptions};
use temporal_sdk::{WfContext, WorkflowResult};
use assert_matches::assert_matches;
use std::{collections::HashMap, env, time::Duration};
use temporal_client::{
GetWorkflowResultOpts, WfClientExt, WorkflowClientTrait, WorkflowExecutionResult,
WorkflowOptions,
};
use temporal_sdk::{WfContext, WfExitValue, WorkflowResult};
use temporal_sdk_core_protos::coresdk::{AsJsonPayloadExt, FromJsonPayloadExt};
use temporal_sdk_core_test_utils::{CoreWfStarter, INTEG_TEMPORAL_DEV_SERVER_USED_ENV_VAR};
use tracing::warn;
Expand All @@ -12,11 +16,24 @@ static TXT_ATTR: &str = "CustomTextField";
static INT_ATTR: &str = "CustomIntField";

async fn search_attr_updater(ctx: WfContext) -> WorkflowResult<()> {
let mut int_val = ctx
.search_attributes()
.indexed_fields
.get(INT_ATTR)
.cloned()
.unwrap_or_default();
let orig_val = int_val.data[0];
int_val.data[0] += 1;
ctx.upsert_search_attributes([
(TXT_ATTR.to_string(), "goodbye".as_json_payload().unwrap()),
(INT_ATTR.to_string(), 98.as_json_payload().unwrap()),
(TXT_ATTR.to_string(), "goodbye".as_json_payload()?),
(INT_ATTR.to_string(), int_val),
]);
Ok(().into())
// 49 is ascii 1
if orig_val == 49 {
Ok(WfExitValue::ContinueAsNew(Box::default()))
} else {
Ok(().into())
}
}

#[tokio::test]
Expand All @@ -33,7 +50,7 @@ async fn sends_upsert() {
}

worker.register_wf(wf_name, search_attr_updater);
let run_id = worker
worker
.submit_wf(
wf_id.to_string(),
wf_name,
Expand All @@ -43,17 +60,17 @@ async fn sends_upsert() {
(TXT_ATTR.to_string(), "hello".as_json_payload().unwrap()),
(INT_ATTR.to_string(), 1.as_json_payload().unwrap()),
])),
execution_timeout: Some(Duration::from_secs(4)),
..Default::default()
},
)
.await
.unwrap();
worker.run_until_done().await.unwrap();

let search_attrs = starter
.get_client()
.await
.describe_workflow_execution(wf_id.to_string(), Some(run_id))
let client = starter.get_client().await;
let search_attrs = client
.describe_workflow_execution(wf_id.to_string(), None)
.await
.unwrap()
.workflow_execution_info
Expand All @@ -70,5 +87,11 @@ async fn sends_upsert() {
"goodbye",
String::from_json_payload(txt_attr_payload).unwrap()
);
assert_eq!(98, usize::from_json_payload(int_attr_payload).unwrap());
assert_eq!(3, usize::from_json_payload(int_attr_payload).unwrap());
let handle = client.get_untyped_workflow_handle(wf_id.to_string(), "");
let res = handle
.get_workflow_result(GetWorkflowResultOpts::default())
.await
.unwrap();
assert_matches!(res, WorkflowExecutionResult::Succeeded(_));
}
Loading