diff --git a/data.go b/data.go index 47cd619..e09092d 100644 --- a/data.go +++ b/data.go @@ -3,6 +3,7 @@ package astits import ( "encoding/binary" "fmt" + "github.com/asticode/go-astikit" ) @@ -35,7 +36,7 @@ type MuxerData struct { } // parseData parses a payload spanning over multiple packets and returns a set of data -func parseData(ps []*Packet, prs PacketsParser, pm *programMap) (ds []*DemuxerData, err error) { +func parseData(ps []*Packet, prs PacketsParser, pm *programMap, esm *elementaryStreamMap) (ds []*DemuxerData, err error) { // Use custom parser first if prs != nil { var skip bool @@ -79,7 +80,7 @@ func parseData(ps []*Packet, prs PacketsParser, pm *programMap) (ds []*DemuxerDa if pid == PIDCAT { // Information in a CAT payload is private and dependent on the CA system. Use the PacketsParser // to parse this type of payload - } else if isPSIPayload(pid, pm) { + } else if isPSIPayload(pid, pm, esm) { // Parse PSI data var psiData *PSIData if psiData, err = parsePSIData(i); err != nil { @@ -110,10 +111,11 @@ func parseData(ps []*Packet, prs PacketsParser, pm *programMap) (ds []*DemuxerDa } // isPSIPayload checks whether the payload is a PSI one -func isPSIPayload(pid uint16, pm *programMap) bool { +func isPSIPayload(pid uint16, pm *programMap, esm *elementaryStreamMap) bool { return pid == PIDPAT || // PAT pm.existsUnlocked(pid) || // PMT - ((pid >= 0x10 && pid <= 0x14) || (pid >= 0x1e && pid <= 0x1f)) //DVB + (((pid >= 0x10 && pid <= 0x14) || (pid >= 0x1e && pid <= 0x1f)) && //DVB + !esm.existsLocked(pid)) // for non-DVB } // isPESPayload checks whether the payload is a PES one diff --git a/data_test.go b/data_test.go index 2a8da98..2f9550f 100644 --- a/data_test.go +++ b/data_test.go @@ -11,6 +11,7 @@ import ( func TestParseData(t *testing.T) { // Init pm := newProgramMap() + esm := newElementaryStreamMap() ps := []*Packet{} // Custom parser @@ -20,13 +21,13 @@ func TestParseData(t *testing.T) { skip = true return } - ds, err := parseData(ps, c, pm) + ds, err := parseData(ps, c, pm, esm) assert.NoError(t, err) assert.Equal(t, cds, ds) // Do nothing for CAT ps = []*Packet{{Header: PacketHeader{PID: PIDCAT}}} - ds, err = parseData(ps, nil, pm) + ds, err = parseData(ps, nil, pm, esm) assert.NoError(t, err) assert.Empty(t, ds) @@ -42,7 +43,7 @@ func TestParseData(t *testing.T) { Payload: p[33:], }, } - ds, err = parseData(ps, nil, pm) + ds, err = parseData(ps, nil, pm, esm) assert.NoError(t, err) assert.Equal(t, []*DemuxerData{ { @@ -64,7 +65,7 @@ func TestParseData(t *testing.T) { Payload: p[33:], }, } - ds, err = parseData(ps, nil, pm) + ds, err = parseData(ps, nil, pm, esm) assert.NoError(t, err) assert.Equal(t, psi.toData( &Packet{Header: ps[0].Header, AdaptationField: ps[0].AdaptationField}, @@ -74,15 +75,16 @@ func TestParseData(t *testing.T) { func TestIsPSIPayload(t *testing.T) { pm := newProgramMap() + esm := newElementaryStreamMap() var pids []int for i := 0; i <= 255; i++ { - if isPSIPayload(uint16(i), pm) { + if isPSIPayload(uint16(i), pm, esm) { pids = append(pids, i) } } assert.Equal(t, []int{0, 16, 17, 18, 19, 20, 30, 31}, pids) pm.setUnlocked(uint16(1), uint16(0)) - assert.True(t, isPSIPayload(uint16(1), pm)) + assert.True(t, isPSIPayload(uint16(1), pm, esm)) } func TestIsPESPayload(t *testing.T) { diff --git a/demuxer.go b/demuxer.go index 044856a..6fc2afa 100644 --- a/demuxer.go +++ b/demuxer.go @@ -34,6 +34,7 @@ type Demuxer struct { packetBuffer *packetBuffer packetPool *packetPool programMap *programMap + streamMap *elementaryStreamMap r io.Reader } @@ -52,6 +53,7 @@ func NewDemuxer(ctx context.Context, r io.Reader, opts ...func(*Demuxer)) (d *De ctx: ctx, l: astikit.AdaptStdLogger(nil), programMap: newProgramMap(), + streamMap: newElementaryStreamMap(), r: r, } d.packetPool = newPacketPool(d.programMap) @@ -145,7 +147,7 @@ func (dmx *Demuxer) NextData() (d *DemuxerData, err error) { // Parse data var errParseData error - if ds, errParseData = parseData(ps, dmx.optPacketsParser, dmx.programMap); errParseData != nil { + if ds, errParseData = parseData(ps, dmx.optPacketsParser, dmx.programMap, dmx.streamMap); errParseData != nil { // Log error as there may be some incomplete data here // We still want to try to parse all packets, in case final data is complete dmx.l.Error(fmt.Errorf("astits: parsing data failed: %w", errParseData)) @@ -170,7 +172,7 @@ func (dmx *Demuxer) NextData() (d *DemuxerData, err error) { } // Parse data - if ds, err = parseData(ps, dmx.optPacketsParser, dmx.programMap); err != nil { + if ds, err = parseData(ps, dmx.optPacketsParser, dmx.programMap, dmx.streamMap); err != nil { err = fmt.Errorf("astits: building new data failed: %w", err) return } @@ -199,6 +201,11 @@ func (dmx *Demuxer) updateData(ds []*DemuxerData) (d *DemuxerData) { } } } + if v.PMT != nil { + for _, es := range v.PMT.ElementaryStreams { + dmx.streamMap.setLocked(es.ElementaryPID, v.PMT.ProgramNumber) + } + } } } return diff --git a/stream_map.go b/stream_map.go new file mode 100644 index 0000000..8f22264 --- /dev/null +++ b/stream_map.go @@ -0,0 +1,29 @@ +package astits + +// elementaryStreamMap represents an elementary stream ids map +type elementaryStreamMap struct { + // We use map[uint32] instead map[uint16] as go runtime provide optimized hash functions for (u)int32/64 keys + es map[uint32]uint16 // map[StreamID]ProgramNumber +} + +// newElementaryStreamMap creates a new elementary stream ids map +func newElementaryStreamMap() *elementaryStreamMap { + return &elementaryStreamMap{ + es: make(map[uint32]uint16), + } +} + +// setLocked sets a new stream id to the elementary stream +func (m elementaryStreamMap) setLocked(pid, number uint16) { + m.es[uint32(pid)] = number +} + +// existsLocked checks whether the stream with this pid exists +func (m elementaryStreamMap) existsLocked(pid uint16) (ok bool) { + _, ok = m.es[uint32(pid)] + return +} + +func (m elementaryStreamMap) unsetLocked(pid uint16) { + delete(m.es, uint32(pid)) +} diff --git a/stream_map_test.go b/stream_map_test.go new file mode 100644 index 0000000..21ba97d --- /dev/null +++ b/stream_map_test.go @@ -0,0 +1,16 @@ +package astits + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestElementaryStreamMap(t *testing.T) { + esm := newElementaryStreamMap() + assert.False(t, esm.existsLocked(0x16)) + esm.setLocked(0x16, 1) + assert.True(t, esm.existsLocked(0x16)) + esm.unsetLocked(0x16) + assert.False(t, esm.existsLocked(0x16)) +}