diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs index 41ce716ff63b68..03b5934ea38dc1 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs @@ -347,17 +347,14 @@ private IEnumerator>> GetEnumeratorCore // during enumeration so that we can parse the raw value in order to a) return // the correct set of parsed values, and b) update the instance for subsequent enumerations // to reflect that parsing. - info = new HeaderStoreItemInfo() { RawValue = entry.Value }; - if (EntriesAreLiveView) - { - entries[i].Value = info; - } - else - { - Debug.Assert(Contains(entry.Key)); - ((Dictionary)_headerStore!)[entry.Key] = info; - } +#nullable disable // https://github.com/dotnet/roslyn/issues/73928 + ref object storeValueRef = ref EntriesAreLiveView + ? ref entries[i].Value + : ref CollectionsMarshal.GetValueRefOrNullRef((Dictionary)_headerStore, entry.Key); + + info = ReplaceWithHeaderStoreItemInfo(ref storeValueRef, entry.Value); +#nullable restore } // Make sure we parse all raw values before returning the result. Note that this has to be @@ -729,15 +726,10 @@ private bool TryGetAndParseHeaderInfo(HeaderDescriptor key, [NotNullWhen(true)] if (!Unsafe.IsNullRef(ref storeValueRef)) { object value = storeValueRef; - if (value is HeaderStoreItemInfo hsi) - { - info = hsi; - } - else - { - Debug.Assert(value is string); - storeValueRef = info = new HeaderStoreItemInfo() { RawValue = value }; - } + + info = value is HeaderStoreItemInfo hsi + ? hsi + : ReplaceWithHeaderStoreItemInfo(ref storeValueRef, value); ParseRawHeaderValues(key, info); return true; @@ -747,6 +739,31 @@ private bool TryGetAndParseHeaderInfo(HeaderDescriptor key, [NotNullWhen(true)] return false; } + /// + /// Replaces with a new , + /// or returns the existing if a different thread beat us to it. + /// + /// + /// This helper should be used any time we're upgrading a storage slot from an unparsed string to a HeaderStoreItemInfo *while reading*. + /// Concurrent writes to the header collection are UB, so we don't need to worry about race conditions when doing the replacement there. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static HeaderStoreItemInfo ReplaceWithHeaderStoreItemInfo(ref object storeValueRef, object value) + { + Debug.Assert(value is string); + + var info = new HeaderStoreItemInfo() { RawValue = value }; + object previousValue = Interlocked.CompareExchange(ref storeValueRef, info, value); + + if (ReferenceEquals(previousValue, value)) + { + return info; + } + + // Rare race condition: Another thread replaced the value with a HeaderStoreItemInfo. + return (HeaderStoreItemInfo)previousValue; + } + private static void ParseRawHeaderValues(HeaderDescriptor descriptor, HeaderStoreItemInfo info) { // Unlike TryGetHeaderInfo() this method tries to parse all non-validated header values (if any) diff --git a/src/libraries/System.Net.Http/tests/UnitTests/Headers/HttpHeadersTest.cs b/src/libraries/System.Net.Http/tests/UnitTests/Headers/HttpHeadersTest.cs index e266c60fb134de..ce306a2e4cd48e 100644 --- a/src/libraries/System.Net.Http/tests/UnitTests/Headers/HttpHeadersTest.cs +++ b/src/libraries/System.Net.Http/tests/UnitTests/Headers/HttpHeadersTest.cs @@ -3,6 +3,7 @@ using System.Collections; using System.Collections.Generic; +using System.Diagnostics; using System.Globalization; using System.Linq; using System.Net.Http.Headers; @@ -2502,6 +2503,51 @@ static HttpRequestHeaders CreateHeaders() } } + [Theory] + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task ConcurrentReads_ReturnTheSameParsedValues(bool useDictionary, bool useTypedProperty) + { + HttpContentHeaders dummyValues = new ByteArrayContent([]).Headers; + if (useDictionary) + { + for (int i = 0; i < HttpHeaders.ArrayThreshold; i++) + { + Assert.True(dummyValues.TryAddWithoutValidation($"foo-{i}", "Foo")); + } + } + + Stopwatch s = Stopwatch.StartNew(); + + while (s.ElapsedMilliseconds < 100) + { + HttpContentHeaders headers = new ByteArrayContent([]).Headers; + + headers.AddHeaders(dummyValues); + + Assert.True(headers.TryAddWithoutValidation("Content-Type", "application/json; charset=utf-8")); + + if (useTypedProperty) + { + Task task = Task.Run(() => headers.ContentType); + MediaTypeHeaderValue contentType1 = headers.ContentType; + MediaTypeHeaderValue contentType2 = await task; + + Assert.Same(contentType1, contentType2); + } + else + { + Task task = Task.Run(() => headers.Count()); // Force enumeration + MediaTypeHeaderValue contentType1 = headers.ContentType; + await task; + + Assert.Same(contentType1, headers.ContentType); + } + } + } + [Fact] public void TryAddInvalidHeader_ShouldThrowFormatException() {