diff --git a/lib/mosquitto_internal.h b/lib/mosquitto_internal.h index 94d640362..f3c546231 100644 --- a/lib/mosquitto_internal.h +++ b/lib/mosquitto_internal.h @@ -95,6 +95,7 @@ enum mosquitto_msg_direction { }; enum mosquitto_msg_state { + mosq_ms_any = -1, mosq_ms_invalid = 0, mosq_ms_publish_qos0 = 1, mosq_ms_publish_qos1 = 2, diff --git a/src/database.c b/src/database.c index f9efe0a33..75f4e026a 100644 --- a/src/database.c +++ b/src/database.c @@ -411,7 +411,7 @@ int db__message_delete_outgoing(struct mosquitto *context, uint16_t mid, enum mo if(client_msg->data.mid == mid){ if(client_msg->data.qos != qos){ return MOSQ_ERR_PROTOCOL; - }else if(qos == 2 && client_msg->data.state != expect_state){ + }else if(qos == 2 && client_msg->data.state != expect_state && expect_state != mosq_ms_any){ return MOSQ_ERR_PROTOCOL; } db__message_remove_inflight(context, &context->msgs_out, client_msg); @@ -425,7 +425,7 @@ int db__message_delete_outgoing(struct mosquitto *context, uint16_t mid, enum mo if(client_msg->data.mid == mid){ if(client_msg->data.qos != qos){ return MOSQ_ERR_PROTOCOL; - }else if(qos == 2 && client_msg->data.state != expect_state){ + }else if(qos == 2 && client_msg->data.state != expect_state && expect_state != mosq_ms_any){ return MOSQ_ERR_PROTOCOL; } db__message_remove_queued(context, &context->msgs_out, client_msg); @@ -703,11 +703,12 @@ int db__message_insert_outgoing(struct mosquitto *context, uint64_t cmsg_id, uin return rc; } -int db__message_update_outgoing(struct mosquitto *context, uint16_t mid, enum mosquitto_msg_state state, int qos, bool persist) +static inline int db__message_update_outgoing_state(struct mosquitto *context, struct mosquitto__client_msg *head, + uint16_t mid, enum mosquitto_msg_state state, int qos, bool persist) { struct mosquitto__client_msg *client_msg; - DL_FOREACH(context->msgs_out.inflight, client_msg){ + DL_FOREACH(head, client_msg){ if(client_msg->data.mid == mid){ if(client_msg->data.qos != qos){ return MOSQ_ERR_PROTOCOL; @@ -722,6 +723,17 @@ int db__message_update_outgoing(struct mosquitto *context, uint16_t mid, enum mo return MOSQ_ERR_NOT_FOUND; } +int db__message_update_outgoing(struct mosquitto *context, uint16_t mid, enum mosquitto_msg_state state, int qos, bool persist) +{ + int rc; + + rc = db__message_update_outgoing_state(context, context->msgs_out.inflight, mid, state, qos, persist); + if (!persist && rc == MOSQ_ERR_NOT_FOUND){ + rc = db__message_update_outgoing_state(context, context->msgs_out.queued, mid, state, qos, persist); + } + return rc; +} + static void db__messages_delete_list(struct mosquitto__client_msg **head) { diff --git a/src/plugin_public.c b/src/plugin_public.c index a41991247..3b9ef361e 100644 --- a/src/plugin_public.c +++ b/src/plugin_public.c @@ -658,14 +658,14 @@ BROKER_EXPORT int mosquitto_persist_client_msg_delete(struct mosquitto_client_ms return MOSQ_ERR_NOT_FOUND; } + + int rc = MOSQ_ERR_INVAL; if(client_msg->direction == mosq_md_out){ - return db__message_delete_outgoing(context, client_msg->mid, client_msg->state, client_msg->qos); + rc = db__message_delete_outgoing(context, client_msg->mid, mosq_ms_any, client_msg->qos); }else if(client_msg->direction == mosq_md_in){ - return db__message_remove_incoming(context, client_msg->mid); - }else{ - return MOSQ_ERR_INVAL; + rc = db__message_remove_incoming(context, client_msg->mid); } - return MOSQ_ERR_SUCCESS; + return rc; }